From 12c4bb5d99710e437e4e0e50a65991df4115566d Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Wed, 27 Nov 2024 20:01:08 +0100 Subject: [PATCH 1/5] chore: track usage of built-in example templates (#15671) Addresses https://github.com/coder/nexus/issues/99. Changes: - Save the id of the built-in example template used to create a template version in the database - Include the example id in telemetry (cherry picked from commit b830c05e3e10c82a1e90c59d340690d911613855) --- coderd/database/dbgen/dbgen.go | 21 ++++---- coderd/database/dbmem/dbmem.go | 21 ++++---- coderd/database/dump.sql | 4 +- ...0277_template_version_example_ids.down.sql | 28 +++++++++++ ...000277_template_version_example_ids.up.sql | 30 +++++++++++ coderd/database/models.go | 6 ++- coderd/database/queries.sql.go | 50 +++++++++++-------- coderd/database/queries/templateversions.sql | 5 +- coderd/telemetry/telemetry.go | 14 ++++-- coderd/telemetry/telemetry_test.go | 18 ++++++- coderd/templateversions.go | 4 ++ coderd/templateversions_test.go | 17 ++++++- docs/admin/security/audit-logs.md | 2 +- enterprise/audit/table.go | 1 + 14 files changed, 167 insertions(+), 54 deletions(-) create mode 100644 coderd/database/migrations/000277_template_version_example_ids.down.sql create mode 100644 coderd/database/migrations/000277_template_version_example_ids.up.sql diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index ae898d4f1fdc3..9c8696112dea8 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -788,16 +788,17 @@ func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVers err := db.InTx(func(db database.Store) error { versionID := takeFirst(orig.ID, uuid.New()) err := db.InsertTemplateVersion(genCtx, database.InsertTemplateVersionParams{ - ID: versionID, - TemplateID: takeFirst(orig.TemplateID, uuid.NullUUID{}), - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), - Name: takeFirst(orig.Name, testutil.GetRandomName(t)), - Message: orig.Message, - Readme: takeFirst(orig.Readme, testutil.GetRandomName(t)), - JobID: takeFirst(orig.JobID, uuid.New()), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + ID: versionID, + TemplateID: takeFirst(orig.TemplateID, uuid.NullUUID{}), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), + Name: takeFirst(orig.Name, testutil.GetRandomName(t)), + Message: orig.Message, + Readme: takeFirst(orig.Readme, testutil.GetRandomName(t)), + JobID: takeFirst(orig.JobID, uuid.New()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + SourceExampleID: takeFirst(orig.SourceExampleID, sql.NullString{}), }) if err != nil { return err diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 5583fff111648..765573b311a84 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7699,16 +7699,17 @@ func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.Inse //nolint:gosimple version := database.TemplateVersionTable{ - ID: arg.ID, - TemplateID: arg.TemplateID, - OrganizationID: arg.OrganizationID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Message: arg.Message, - Readme: arg.Readme, - JobID: arg.JobID, - CreatedBy: arg.CreatedBy, + ID: arg.ID, + TemplateID: arg.TemplateID, + OrganizationID: arg.OrganizationID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Name: arg.Name, + Message: arg.Message, + Readme: arg.Readme, + JobID: arg.JobID, + CreatedBy: arg.CreatedBy, + SourceExampleID: arg.SourceExampleID, } q.templateVersions = append(q.templateVersions, version) return nil diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 9919011579bde..eba9b7cf106d3 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1217,7 +1217,8 @@ CREATE TABLE template_versions ( created_by uuid NOT NULL, external_auth_providers jsonb DEFAULT '[]'::jsonb NOT NULL, message character varying(1048576) DEFAULT ''::character varying NOT NULL, - archived boolean DEFAULT false NOT NULL + archived boolean DEFAULT false NOT NULL, + source_example_id text ); COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External auth providers for a specific template version'; @@ -1245,6 +1246,7 @@ CREATE VIEW template_version_with_user AS template_versions.external_auth_providers, template_versions.message, template_versions.archived, + template_versions.source_example_id, COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url, COALESCE(visible_users.username, ''::text) AS created_by_username FROM (template_versions diff --git a/coderd/database/migrations/000277_template_version_example_ids.down.sql b/coderd/database/migrations/000277_template_version_example_ids.down.sql new file mode 100644 index 0000000000000..ad961e9f635c7 --- /dev/null +++ b/coderd/database/migrations/000277_template_version_example_ids.down.sql @@ -0,0 +1,28 @@ +-- We cannot alter the column type while a view depends on it, so we drop it and recreate it. +DROP VIEW template_version_with_user; + +ALTER TABLE + template_versions +DROP COLUMN source_example_id; + +-- Recreate `template_version_with_user` as described in dump.sql +CREATE VIEW template_version_with_user AS +SELECT + template_versions.id, + template_versions.template_id, + template_versions.organization_id, + template_versions.created_at, + template_versions.updated_at, + template_versions.name, + template_versions.readme, + template_versions.job_id, + template_versions.created_by, + template_versions.external_auth_providers, + template_versions.message, + template_versions.archived, + COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url, + COALESCE(visible_users.username, ''::text) AS created_by_username +FROM (template_versions + LEFT JOIN visible_users ON (template_versions.created_by = visible_users.id)); + +COMMENT ON VIEW template_version_with_user IS 'Joins in the username + avatar url of the created by user.'; diff --git a/coderd/database/migrations/000277_template_version_example_ids.up.sql b/coderd/database/migrations/000277_template_version_example_ids.up.sql new file mode 100644 index 0000000000000..aca34b31de5dc --- /dev/null +++ b/coderd/database/migrations/000277_template_version_example_ids.up.sql @@ -0,0 +1,30 @@ +-- We cannot alter the column type while a view depends on it, so we drop it and recreate it. +DROP VIEW template_version_with_user; + +ALTER TABLE + template_versions +ADD + COLUMN source_example_id TEXT; + +-- Recreate `template_version_with_user` as described in dump.sql +CREATE VIEW template_version_with_user AS +SELECT + template_versions.id, + template_versions.template_id, + template_versions.organization_id, + template_versions.created_at, + template_versions.updated_at, + template_versions.name, + template_versions.readme, + template_versions.job_id, + template_versions.created_by, + template_versions.external_auth_providers, + template_versions.message, + template_versions.archived, + template_versions.source_example_id, + COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url, + COALESCE(visible_users.username, ''::text) AS created_by_username +FROM (template_versions + LEFT JOIN visible_users ON (template_versions.created_by = visible_users.id)); + +COMMENT ON VIEW template_version_with_user IS 'Joins in the username + avatar url of the created by user.'; diff --git a/coderd/database/models.go b/coderd/database/models.go index af0a3122f7964..6b99245079950 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -2773,6 +2773,7 @@ type TemplateVersion struct { ExternalAuthProviders json.RawMessage `db:"external_auth_providers" json:"external_auth_providers"` Message string `db:"message" json:"message"` Archived bool `db:"archived" json:"archived"` + SourceExampleID sql.NullString `db:"source_example_id" json:"source_example_id"` CreatedByAvatarURL string `db:"created_by_avatar_url" json:"created_by_avatar_url"` CreatedByUsername string `db:"created_by_username" json:"created_by_username"` } @@ -2826,8 +2827,9 @@ type TemplateVersionTable struct { // IDs of External auth providers for a specific template version ExternalAuthProviders json.RawMessage `db:"external_auth_providers" json:"external_auth_providers"` // Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact. - Message string `db:"message" json:"message"` - Archived bool `db:"archived" json:"archived"` + Message string `db:"message" json:"message"` + Archived bool `db:"archived" json:"archived"` + SourceExampleID sql.NullString `db:"source_example_id" json:"source_example_id"` } type TemplateVersionVariable struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 4eec78cf97fba..33a3ce12a444d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -8996,7 +8996,7 @@ FROM -- Scope an archive to a single template and ignore already archived template versions ( SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id FROM template_versions WHERE @@ -9097,7 +9097,7 @@ func (q *sqlQuerier) ArchiveUnusedTemplateVersions(ctx context.Context, arg Arch const getPreviousTemplateVersion = `-- name: GetPreviousTemplateVersion :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9134,6 +9134,7 @@ func (q *sqlQuerier) GetPreviousTemplateVersion(ctx context.Context, arg GetPrev &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9142,7 +9143,7 @@ func (q *sqlQuerier) GetPreviousTemplateVersion(ctx context.Context, arg GetPrev const getTemplateVersionByID = `-- name: GetTemplateVersionByID :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9165,6 +9166,7 @@ func (q *sqlQuerier) GetTemplateVersionByID(ctx context.Context, id uuid.UUID) ( &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9173,7 +9175,7 @@ func (q *sqlQuerier) GetTemplateVersionByID(ctx context.Context, id uuid.UUID) ( const getTemplateVersionByJobID = `-- name: GetTemplateVersionByJobID :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9196,6 +9198,7 @@ func (q *sqlQuerier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.U &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9204,7 +9207,7 @@ func (q *sqlQuerier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.U const getTemplateVersionByTemplateIDAndName = `-- name: GetTemplateVersionByTemplateIDAndName :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9233,6 +9236,7 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9241,7 +9245,7 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, const getTemplateVersionsByIDs = `-- name: GetTemplateVersionsByIDs :many SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9270,6 +9274,7 @@ func (q *sqlQuerier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UU &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ); err != nil { @@ -9288,7 +9293,7 @@ func (q *sqlQuerier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UU const getTemplateVersionsByTemplateID = `-- name: GetTemplateVersionsByTemplateID :many SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9364,6 +9369,7 @@ func (q *sqlQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg Ge &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ); err != nil { @@ -9381,7 +9387,7 @@ func (q *sqlQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg Ge } const getTemplateVersionsCreatedAfter = `-- name: GetTemplateVersionsCreatedAfter :many -SELECT id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE created_at > $1 +SELECT id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE created_at > $1 ` func (q *sqlQuerier) GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) { @@ -9406,6 +9412,7 @@ func (q *sqlQuerier) GetTemplateVersionsCreatedAfter(ctx context.Context, create &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ); err != nil { @@ -9434,23 +9441,25 @@ INSERT INTO message, readme, job_id, - created_by + created_by, + source_example_id ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ` type InsertTemplateVersionParams struct { - ID uuid.UUID `db:"id" json:"id"` - TemplateID uuid.NullUUID `db:"template_id" json:"template_id"` - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Name string `db:"name" json:"name"` - Message string `db:"message" json:"message"` - Readme string `db:"readme" json:"readme"` - JobID uuid.UUID `db:"job_id" json:"job_id"` - CreatedBy uuid.UUID `db:"created_by" json:"created_by"` + ID uuid.UUID `db:"id" json:"id"` + TemplateID uuid.NullUUID `db:"template_id" json:"template_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + Message string `db:"message" json:"message"` + Readme string `db:"readme" json:"readme"` + JobID uuid.UUID `db:"job_id" json:"job_id"` + CreatedBy uuid.UUID `db:"created_by" json:"created_by"` + SourceExampleID sql.NullString `db:"source_example_id" json:"source_example_id"` } func (q *sqlQuerier) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) error { @@ -9465,6 +9474,7 @@ func (q *sqlQuerier) InsertTemplateVersion(ctx context.Context, arg InsertTempla arg.Readme, arg.JobID, arg.CreatedBy, + arg.SourceExampleID, ) return err } diff --git a/coderd/database/queries/templateversions.sql b/coderd/database/queries/templateversions.sql index 094c1b6014de7..0436a7f9ba3b9 100644 --- a/coderd/database/queries/templateversions.sql +++ b/coderd/database/queries/templateversions.sql @@ -87,10 +87,11 @@ INSERT INTO message, readme, job_id, - created_by + created_by, + source_example_id ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10); + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); -- name: UpdateTemplateVersionByID :exec UPDATE diff --git a/coderd/telemetry/telemetry.go b/coderd/telemetry/telemetry.go index 8ad85b0b39982..233450c43d943 100644 --- a/coderd/telemetry/telemetry.go +++ b/coderd/telemetry/telemetry.go @@ -868,6 +868,9 @@ func ConvertTemplateVersion(version database.TemplateVersion) TemplateVersion { if version.TemplateID.Valid { snapVersion.TemplateID = &version.TemplateID.UUID } + if version.SourceExampleID.Valid { + snapVersion.SourceExampleID = &version.SourceExampleID.String + } return snapVersion } @@ -1116,11 +1119,12 @@ type Template struct { } type TemplateVersion struct { - ID uuid.UUID `json:"id"` - CreatedAt time.Time `json:"created_at"` - TemplateID *uuid.UUID `json:"template_id,omitempty"` - OrganizationID uuid.UUID `json:"organization_id"` - JobID uuid.UUID `json:"job_id"` + ID uuid.UUID `json:"id"` + CreatedAt time.Time `json:"created_at"` + TemplateID *uuid.UUID `json:"template_id,omitempty"` + OrganizationID uuid.UUID `json:"organization_id"` + JobID uuid.UUID `json:"job_id"` + SourceExampleID *string `json:"source_example_id,omitempty"` } type ProvisionerJob struct { diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index 214d111a170a1..2b70cd2a6d2c3 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -1,6 +1,7 @@ package telemetry_test import ( + "database/sql" "encoding/json" "net/http" "net/http/httptest" @@ -48,6 +49,10 @@ func TestTelemetry(t *testing.T) { _ = dbgen.Template(t, db, database.Template{ Provisioner: database.ProvisionerTypeTerraform, }) + sourceExampleID := uuid.NewString() + _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + SourceExampleID: sql.NullString{String: sourceExampleID, Valid: true}, + }) _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{}) user := dbgen.User(t, db, database.User{}) _ = dbgen.Workspace(t, db, database.WorkspaceTable{}) @@ -93,7 +98,7 @@ func TestTelemetry(t *testing.T) { require.Len(t, snapshot.ProvisionerJobs, 1) require.Len(t, snapshot.Licenses, 1) require.Len(t, snapshot.Templates, 1) - require.Len(t, snapshot.TemplateVersions, 1) + require.Len(t, snapshot.TemplateVersions, 2) require.Len(t, snapshot.Users, 1) require.Len(t, snapshot.Groups, 2) // 1 member in the everyone group + 1 member in the custom group @@ -111,6 +116,17 @@ func TestTelemetry(t *testing.T) { require.Len(t, wsa.Subsystems, 2) require.Equal(t, string(database.WorkspaceAgentSubsystemEnvbox), wsa.Subsystems[0]) require.Equal(t, string(database.WorkspaceAgentSubsystemExectrace), wsa.Subsystems[1]) + + tvs := snapshot.TemplateVersions + sort.Slice(tvs, func(i, j int) bool { + // Sort by SourceExampleID presence (non-nil comes before nil) + if (tvs[i].SourceExampleID != nil) != (tvs[j].SourceExampleID != nil) { + return tvs[i].SourceExampleID != nil + } + return false + }) + require.Equal(t, tvs[0].SourceExampleID, &sourceExampleID) + require.Nil(t, tvs[1].SourceExampleID) }) t.Run("HashedEmail", func(t *testing.T) { t.Parallel() diff --git a/coderd/templateversions.go b/coderd/templateversions.go index a0609c42c33f9..12def3e5d681b 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -1582,6 +1582,10 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht Readme: "", JobID: provisionerJob.ID, CreatedBy: apiKey.UserID, + SourceExampleID: sql.NullString{ + String: req.ExampleID, + Valid: req.ExampleID != "", + }, }) if err != nil { if database.IsUniqueViolation(err, database.UniqueTemplateVersionsTemplateIDNameKey) { diff --git a/coderd/templateversions_test.go b/coderd/templateversions_test.go index 5ebbd0f41804f..5e96de10d5058 100644 --- a/coderd/templateversions_test.go +++ b/coderd/templateversions_test.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/rbac" @@ -134,7 +135,7 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { t.Run("WithParameters", func(t *testing.T) { t.Parallel() auditor := audit.NewMock() - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true, Auditor: auditor}) + client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true, Auditor: auditor}) user := coderdtest.CreateFirstUser(t, client) data, err := echo.Tar(&echo.Responses{ Parse: echo.ParseComplete, @@ -160,11 +161,17 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { require.Len(t, auditor.AuditLogs(), 2) assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[1].Action) + + admin, err := client.User(ctx, user.UserID.String()) + require.NoError(t, err) + tvDB, err := db.GetTemplateVersionByID(dbauthz.As(ctx, coderdtest.AuthzUserSubject(admin, user.OrganizationID)), version.ID) + require.NoError(t, err) + require.False(t, tvDB.SourceExampleID.Valid) }) t.Run("Example", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + client, db := coderdtest.NewWithDatabase(t, nil) user := coderdtest.CreateFirstUser(t, client) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -205,6 +212,12 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { require.NoError(t, err) require.Equal(t, "my-example", tv.Name) + admin, err := client.User(ctx, user.UserID.String()) + require.NoError(t, err) + tvDB, err := db.GetTemplateVersionByID(dbauthz.As(ctx, coderdtest.AuthzUserSubject(admin, user.OrganizationID)), tv.ID) + require.NoError(t, err) + require.Equal(t, ls[0].ID, tvDB.SourceExampleID.String) + // ensure the template tar was uploaded correctly fl, ct, err := client.Download(ctx, tv.Job.FileID) require.NoError(t, err) diff --git a/docs/admin/security/audit-logs.md b/docs/admin/security/audit-logs.md index 3ea4e145d13eb..db214b0e1443e 100644 --- a/docs/admin/security/audit-logs.md +++ b/docs/admin/security/audit-logs.md @@ -24,7 +24,7 @@ We track the following resources: | OAuth2ProviderAppSecret
|
FieldTracked
app_idfalse
created_atfalse
display_secretfalse
hashed_secretfalse
idfalse
last_used_atfalse
secret_prefixfalse
| | Organization
|
FieldTracked
created_atfalse
descriptiontrue
display_nametrue
icontrue
idfalse
is_defaulttrue
nametrue
updated_attrue
| | Template
write, delete |
FieldTracked
active_version_idtrue
activity_bumptrue
allow_user_autostarttrue
allow_user_autostoptrue
allow_user_cancel_workspace_jobstrue
autostart_block_days_of_weektrue
autostop_requirement_days_of_weektrue
autostop_requirement_weekstrue
created_atfalse
created_bytrue
created_by_avatar_urlfalse
created_by_usernamefalse
default_ttltrue
deletedfalse
deprecatedtrue
descriptiontrue
display_nametrue
failure_ttltrue
group_acltrue
icontrue
idtrue
max_port_sharing_leveltrue
nametrue
organization_display_namefalse
organization_iconfalse
organization_idfalse
organization_namefalse
provisionertrue
require_active_versiontrue
time_til_dormanttrue
time_til_dormant_autodeletetrue
updated_atfalse
user_acltrue
| -| TemplateVersion
create, write |
FieldTracked
archivedtrue
created_atfalse
created_bytrue
created_by_avatar_urlfalse
created_by_usernamefalse
external_auth_providersfalse
idtrue
job_idfalse
messagefalse
nametrue
organization_idfalse
readmetrue
template_idtrue
updated_atfalse
| +| TemplateVersion
create, write |
FieldTracked
archivedtrue
created_atfalse
created_bytrue
created_by_avatar_urlfalse
created_by_usernamefalse
external_auth_providersfalse
idtrue
job_idfalse
messagefalse
nametrue
organization_idfalse
readmetrue
source_example_idfalse
template_idtrue
updated_atfalse
| | User
create, write, delete |
FieldTracked
avatar_urlfalse
created_atfalse
deletedtrue
emailtrue
github_com_user_idfalse
hashed_one_time_passcodefalse
hashed_passwordtrue
idtrue
last_seen_atfalse
login_typetrue
nametrue
one_time_passcode_expires_attrue
quiet_hours_scheduletrue
rbac_rolestrue
statustrue
theme_preferencefalse
updated_atfalse
usernametrue
| | WorkspaceBuild
start, stop |
FieldTracked
build_numberfalse
created_atfalse
daily_costfalse
deadlinefalse
idfalse
initiator_by_avatar_urlfalse
initiator_by_usernamefalse
initiator_idfalse
job_idfalse
max_deadlinefalse
provisioner_statefalse
reasonfalse
template_version_idtrue
transitionfalse
updated_atfalse
workspace_idfalse
| | WorkspaceProxy
|
FieldTracked
created_attrue
deletedfalse
derp_enabledtrue
derp_onlytrue
display_nametrue
icontrue
idtrue
nametrue
region_idtrue
token_hashed_secrettrue
updated_atfalse
urltrue
versiontrue
wildcard_hostnametrue
| diff --git a/enterprise/audit/table.go b/enterprise/audit/table.go index f9e74959f2a28..24f7dfa4b4fe0 100644 --- a/enterprise/audit/table.go +++ b/enterprise/audit/table.go @@ -127,6 +127,7 @@ var auditableResourcesTypes = map[any]map[string]Action{ "created_by_avatar_url": ActionIgnore, "created_by_username": ActionIgnore, "archived": ActionTrack, + "source_example_id": ActionIgnore, // Never changes. }, &database.User{}: { "id": ActionTrack, From bf74f8b7f9efb184a002ee97fcd5c098972215ca Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Mon, 2 Dec 2024 21:27:43 +0200 Subject: [PATCH 2/5] feat(site): show license utilization in general settings (#15683) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is the first iteration towards #15297 We cannot yet show license utilization over time, so we show current license utilization. This is because we don't track user states over time. We only track the current user state. A graph over time filtering by active users would therefore not account for day to day changes in user state and be inaccurate. DB schema migrations and related updates will follow that allow us to show license utilization over time. ![image](https://github.com/user-attachments/assets/91bd6e8c-e74c-4ef5-aa6b-271fd245da37) --------- Co-authored-by: ケイラ (cherry picked from commit 7e1ac2e22b6989d6939a7b4749b142fc574294bf) --- .../GeneralSettingsPageView.stories.tsx | 72 +++++++++++++++++++ .../GeneralSettingsPageView.tsx | 38 ++++++++++ 2 files changed, 110 insertions(+) diff --git a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx index 9147a1a5befff..05ed426d5dcc9 100644 --- a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx +++ b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx @@ -42,6 +42,7 @@ const meta: Meta = { deploymentDAUs: MockDeploymentDAUResponse, invalidExperiments: [], safeExperiments: [], + entitlements: undefined, }, }; @@ -136,3 +137,74 @@ export const invalidExperimentsEnabled: Story = { invalidExperiments: ["invalid"], }, }; + +export const WithLicenseUtilization: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: true, + actual: 75, + limit: 100, + entitlement: "entitled", + }, + }, + }, + }, +}; + +export const HighLicenseUtilization: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: true, + actual: 95, + limit: 100, + entitlement: "entitled", + }, + }, + }, + }, +}; + +export const ExceedsLicenseUtilization: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: true, + actual: 100, + limit: 95, + entitlement: "entitled", + }, + }, + }, + }, +}; +export const NoLicenseLimit: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: false, + actual: 0, + limit: 0, + entitlement: "entitled", + }, + }, + }, + }, +}; diff --git a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx index 29edacd08d9e7..df5550d70e965 100644 --- a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx +++ b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx @@ -1,4 +1,5 @@ import AlertTitle from "@mui/material/AlertTitle"; +import LinearProgress from "@mui/material/LinearProgress"; import type { DAUsResponse, Entitlements, @@ -36,6 +37,12 @@ export const GeneralSettingsPageView: FC = ({ safeExperiments, invalidExperiments, }) => { + const licenseUtilizationPercentage = + entitlements?.features?.user_limit?.actual && + entitlements?.features?.user_limit?.limit + ? entitlements.features.user_limit.actual / + entitlements.features.user_limit.limit + : undefined; return ( <> = ({ )} + {licenseUtilizationPercentage && ( + + + + {Math.round(licenseUtilizationPercentage * 100)}% used ( + {entitlements!.features.user_limit.actual}/ + {entitlements!.features.user_limit.limit} users) + + + )} {invalidExperiments.length > 0 && ( Invalid experiments in use: From 54f76059be79675832f6b6a06930678bf9881814 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 28 Nov 2024 16:58:32 +0200 Subject: [PATCH 3/5] feat(site): warn on provisioner health during builds (#15589) This PR adds warning alerts to log drawers for templates and template versions. warning alerts for workspace builds to follow in a subsequent PR. Phrasing to be finalised. Stories added and manually verified. See screenshots below. Updating a template version with no provisioners: Screenshot 2024-11-27 at 11 06 28 Build Errors for template versions now show tags as well: Screenshot 2024-11-27 at 11 07 01 Updating a template version with provisioners that are busy or unresponsive: Screenshot 2024-11-27 at 11 06 40 Creating a new template with provisioners that are busy or unresponsive: Screenshot 2024-11-27 at 11 08 55 Creating a new template when there are no provisioners to do the build: Screenshot 2024-11-27 at 11 08 45 (cherry picked from commit 56c792ab52ea14b501eb84ee1090966cbd0b410c) --- site/src/api/api.ts | 10 ++- site/src/api/queries/organizations.ts | 18 ++--- site/src/components/Alert/Alert.tsx | 3 + .../provisioners/ProvisionerAlert.stories.tsx | 28 ++++++++ .../modules/provisioners/ProvisionerAlert.tsx | 45 +++++++++++++ .../ProvisionerStatusAlert.stories.tsx | 55 +++++++++++++++ .../provisioners/ProvisionerStatusAlert.tsx | 47 +++++++++++++ .../BuildLogsDrawer.stories.tsx | 36 ++++++++++ .../CreateTemplatePage/BuildLogsDrawer.tsx | 16 ++++- .../TemplateVersionEditor.stories.tsx | 67 +++++++++++++++++++ .../TemplateVersionEditor.tsx | 42 +++++++----- site/src/testHelpers/storybook.tsx | 25 +++++++ 12 files changed, 365 insertions(+), 27 deletions(-) create mode 100644 site/src/modules/provisioners/ProvisionerAlert.stories.tsx create mode 100644 site/src/modules/provisioners/ProvisionerAlert.tsx create mode 100644 site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx create mode 100644 site/src/modules/provisioners/ProvisionerStatusAlert.tsx diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 3ad195f2bd9e4..cfba27408e9c6 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -682,12 +682,20 @@ class ApiMethods { /** * @param organization Can be the organization's ID or name + * @param tags to filter provisioner daemons by. */ getProvisionerDaemonsByOrganization = async ( organization: string, + tags?: Record, ): Promise => { + const params = new URLSearchParams(); + + if (tags) { + params.append("tags", JSON.stringify(tags)); + } + const response = await this.axios.get( - `/api/v2/organizations/${organization}/provisionerdaemons`, + `/api/v2/organizations/${organization}/provisionerdaemons?${params.toString()}`, ); return response.data; }; diff --git a/site/src/api/queries/organizations.ts b/site/src/api/queries/organizations.ts index d1df8f409dcdf..c3f5a4ebd3ced 100644 --- a/site/src/api/queries/organizations.ts +++ b/site/src/api/queries/organizations.ts @@ -115,16 +115,18 @@ export const organizations = () => { }; }; -export const getProvisionerDaemonsKey = (organization: string) => [ - "organization", - organization, - "provisionerDaemons", -]; +export const getProvisionerDaemonsKey = ( + organization: string, + tags?: Record, +) => ["organization", organization, tags, "provisionerDaemons"]; -export const provisionerDaemons = (organization: string) => { +export const provisionerDaemons = ( + organization: string, + tags?: Record, +) => { return { - queryKey: getProvisionerDaemonsKey(organization), - queryFn: () => API.getProvisionerDaemonsByOrganization(organization), + queryKey: getProvisionerDaemonsKey(organization, tags), + queryFn: () => API.getProvisionerDaemonsByOrganization(organization, tags), }; }; diff --git a/site/src/components/Alert/Alert.tsx b/site/src/components/Alert/Alert.tsx index df741a1924fa9..7750a6bc7d1e8 100644 --- a/site/src/components/Alert/Alert.tsx +++ b/site/src/components/Alert/Alert.tsx @@ -1,4 +1,5 @@ import MuiAlert, { + type AlertColor as MuiAlertColor, type AlertProps as MuiAlertProps, // biome-ignore lint/nursery/noRestrictedImports: Used as base component } from "@mui/material/Alert"; @@ -11,6 +12,8 @@ import { useState, } from "react"; +export type AlertColor = MuiAlertColor; + export type AlertProps = MuiAlertProps & { actions?: ReactNode; dismissible?: boolean; diff --git a/site/src/modules/provisioners/ProvisionerAlert.stories.tsx b/site/src/modules/provisioners/ProvisionerAlert.stories.tsx new file mode 100644 index 0000000000000..d9ca1501d6611 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerAlert.stories.tsx @@ -0,0 +1,28 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import { chromatic } from "testHelpers/chromatic"; +import { ProvisionerAlert } from "./ProvisionerAlert"; + +const meta: Meta = { + title: "modules/provisioners/ProvisionerAlert", + parameters: { + chromatic, + layout: "centered", + }, + component: ProvisionerAlert, + args: { + title: "Title", + detail: "Detail", + severity: "info", + tags: { tag: "tagValue" }, + }, +}; + +export default meta; +type Story = StoryObj; + +export const Info: Story = {}; +export const NullTags: Story = { + args: { + tags: undefined, + }, +}; diff --git a/site/src/modules/provisioners/ProvisionerAlert.tsx b/site/src/modules/provisioners/ProvisionerAlert.tsx new file mode 100644 index 0000000000000..54d9ab8473e87 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerAlert.tsx @@ -0,0 +1,45 @@ +import AlertTitle from "@mui/material/AlertTitle"; +import { Alert, type AlertColor } from "components/Alert/Alert"; +import { AlertDetail } from "components/Alert/Alert"; +import { Stack } from "components/Stack/Stack"; +import { ProvisionerTag } from "modules/provisioners/ProvisionerTag"; +import type { FC } from "react"; +interface ProvisionerAlertProps { + title: string; + detail: string; + severity: AlertColor; + tags: Record; +} + +export const ProvisionerAlert: FC = ({ + title, + detail, + severity, + tags, +}) => { + return ( + { + return { + borderRadius: 0, + border: 0, + borderBottom: `1px solid ${theme.palette.divider}`, + borderLeft: `2px solid ${theme.palette[severity].main}`, + }; + }} + > + {title} + +
{detail}
+ + {Object.entries(tags ?? {}) + .filter(([key]) => key !== "owner") + .map(([key, value]) => ( + + ))} + +
+
+ ); +}; diff --git a/site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx b/site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx new file mode 100644 index 0000000000000..d4f746e99c417 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx @@ -0,0 +1,55 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import { chromatic } from "testHelpers/chromatic"; +import { MockTemplateVersion } from "testHelpers/entities"; +import { ProvisionerStatusAlert } from "./ProvisionerStatusAlert"; + +const meta: Meta = { + title: "modules/provisioners/ProvisionerStatusAlert", + parameters: { + chromatic, + layout: "centered", + }, + component: ProvisionerStatusAlert, + args: { + matchingProvisioners: 0, + availableProvisioners: 0, + tags: MockTemplateVersion.job.tags, + }, +}; + +export default meta; +type Story = StoryObj; + +export const HealthyProvisioners: Story = { + args: { + matchingProvisioners: 1, + availableProvisioners: 1, + }, +}; + +export const UndefinedMatchingProvisioners: Story = { + args: { + matchingProvisioners: undefined, + availableProvisioners: undefined, + }, +}; + +export const UndefinedAvailableProvisioners: Story = { + args: { + matchingProvisioners: 1, + availableProvisioners: undefined, + }, +}; + +export const NoMatchingProvisioners: Story = { + args: { + matchingProvisioners: 0, + }, +}; + +export const NoAvailableProvisioners: Story = { + args: { + matchingProvisioners: 1, + availableProvisioners: 0, + }, +}; diff --git a/site/src/modules/provisioners/ProvisionerStatusAlert.tsx b/site/src/modules/provisioners/ProvisionerStatusAlert.tsx new file mode 100644 index 0000000000000..54a2b56704877 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerStatusAlert.tsx @@ -0,0 +1,47 @@ +import type { AlertColor } from "components/Alert/Alert"; +import type { FC } from "react"; +import { ProvisionerAlert } from "./ProvisionerAlert"; + +interface ProvisionerStatusAlertProps { + matchingProvisioners: number | undefined; + availableProvisioners: number | undefined; + tags: Record; +} + +export const ProvisionerStatusAlert: FC = ({ + matchingProvisioners, + availableProvisioners, + tags, +}) => { + let title: string; + let detail: string; + let severity: AlertColor; + switch (true) { + case matchingProvisioners === 0: + title = "Build pending provisioner deployment"; + detail = + "Your build has been enqueued, but there are no provisioners that accept the required tags. Once a compatible provisioner becomes available, your build will continue. Please contact your administrator."; + severity = "warning"; + break; + case availableProvisioners === 0: + title = "Build delayed"; + detail = + "Provisioners that accept the required tags have not responded for longer than expected. This may delay your build. Please contact your administrator if your build does not complete."; + severity = "warning"; + break; + default: + title = "Build enqueued"; + detail = + "Your build has been enqueued and will begin once a provisioner becomes available to process it."; + severity = "info"; + } + + return ( + + ); +}; diff --git a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx index afc3c1321a6b4..29229fadfd0ad 100644 --- a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx +++ b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx @@ -34,6 +34,42 @@ export const MissingVariables: Story = { }, }; +export const NoProvisioners: Story = { + args: { + templateVersion: { + ...MockTemplateVersion, + matched_provisioners: { + count: 0, + available: 0, + }, + }, + }, +}; + +export const ProvisionersUnhealthy: Story = { + args: { + templateVersion: { + ...MockTemplateVersion, + matched_provisioners: { + count: 1, + available: 0, + }, + }, + }, +}; + +export const ProvisionersHealthy: Story = { + args: { + templateVersion: { + ...MockTemplateVersion, + matched_provisioners: { + count: 1, + available: 1, + }, + }, + }, +}; + export const Logs: Story = { args: { templateVersion: { diff --git a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx index 5af38b649c695..4eb1805b60e36 100644 --- a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx +++ b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx @@ -8,6 +8,7 @@ import { visuallyHidden } from "@mui/utils"; import { JobError } from "api/queries/templates"; import type { TemplateVersion } from "api/typesGenerated"; import { Loader } from "components/Loader/Loader"; +import { ProvisionerStatusAlert } from "modules/provisioners/ProvisionerStatusAlert"; import { useWatchVersionLogs } from "modules/templates/useWatchVersionLogs"; import { WorkspaceBuildLogs } from "modules/workspaces/WorkspaceBuildLogs/WorkspaceBuildLogs"; import { type FC, useLayoutEffect, useRef } from "react"; @@ -27,6 +28,10 @@ export const BuildLogsDrawer: FC = ({ variablesSectionRef, ...drawerProps }) => { + const matchingProvisioners = templateVersion?.matched_provisioners?.count; + const availableProvisioners = + templateVersion?.matched_provisioners?.available; + const logs = useWatchVersionLogs(templateVersion); const logsContainer = useRef(null); @@ -65,6 +70,8 @@ export const BuildLogsDrawer: FC = ({ + {} + {isMissingVariables ? ( { @@ -82,7 +89,14 @@ export const BuildLogsDrawer: FC = ({ ) : ( - + <> + + + )} diff --git a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx index 1382aa100a1dc..4b8413215c9e8 100644 --- a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx +++ b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx @@ -49,6 +49,73 @@ type Story = StoryObj; export const Example: Story = {}; +export const UndefinedLogs: Story = { + args: { + defaultTab: "logs", + buildLogs: undefined, + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + }, + }, +}; + +export const EmptyLogs: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + }, + }, +}; + +export const NoProvisioners: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + matched_provisioners: { + count: 0, + available: 0, + }, + }, + }, +}; + +export const UnavailableProvisioners: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + matched_provisioners: { + count: 1, + available: 0, + }, + }, + }, +}; + +export const HealthyProvisioners: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + matched_provisioners: { + count: 1, + available: 1, + }, + }, + }, +}; + export const Logs: Story = { args: { defaultTab: "logs", diff --git a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx index 943370f89e2a4..858f57dd59493 100644 --- a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx +++ b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx @@ -4,7 +4,6 @@ import ArrowBackOutlined from "@mui/icons-material/ArrowBackOutlined"; import CloseOutlined from "@mui/icons-material/CloseOutlined"; import PlayArrowOutlined from "@mui/icons-material/PlayArrowOutlined"; import WarningOutlined from "@mui/icons-material/WarningOutlined"; -import AlertTitle from "@mui/material/AlertTitle"; import Button from "@mui/material/Button"; import ButtonGroup from "@mui/material/ButtonGroup"; import IconButton from "@mui/material/IconButton"; @@ -17,7 +16,7 @@ import type { VariableValue, WorkspaceResource, } from "api/typesGenerated"; -import { Alert, AlertDetail } from "components/Alert/Alert"; +import { Alert } from "components/Alert/Alert"; import { Sidebar } from "components/FullPageLayout/Sidebar"; import { Topbar, @@ -29,6 +28,8 @@ import { } from "components/FullPageLayout/Topbar"; import { Loader } from "components/Loader/Loader"; import { linkToTemplate, useLinks } from "modules/navigation"; +import { ProvisionerAlert } from "modules/provisioners/ProvisionerAlert"; +import { ProvisionerStatusAlert } from "modules/provisioners/ProvisionerStatusAlert"; import { TemplateFileTree } from "modules/templates/TemplateFiles/TemplateFileTree"; import { isBinaryData } from "modules/templates/TemplateFiles/isBinaryData"; import { TemplateResourcesTable } from "modules/templates/TemplateResourcesTable/TemplateResourcesTable"; @@ -126,6 +127,8 @@ export const TemplateVersionEditor: FC = ({ const [deleteFileOpen, setDeleteFileOpen] = useState(); const [renameFileOpen, setRenameFileOpen] = useState(); const [dirty, setDirty] = useState(false); + const matchingProvisioners = templateVersion.matched_provisioners?.count; + const availableProvisioners = templateVersion.matched_provisioners?.available; const triggerPreview = useCallback(async () => { await onPreview(fileTree); @@ -192,6 +195,8 @@ export const TemplateVersionEditor: FC = ({ linkToTemplate(template.organization_name, template.name), ); + const gotBuildLogs = buildLogs && buildLogs.length > 0; + return ( <>
@@ -581,31 +586,34 @@ export const TemplateVersionEditor: FC = ({ css={[styles.logs, styles.tabContent]} ref={logsContentRef} > - {templateVersion.job.error && ( + {templateVersion.job.error ? (
- - Error during the build - {templateVersion.job.error} - + tags={templateVersion.job.tags} + />
+ ) : ( + !gotBuildLogs && ( + <> + + + + ) )} - {buildLogs && buildLogs.length > 0 ? ( + {gotBuildLogs && ( - ) : ( - )}
)} diff --git a/site/src/testHelpers/storybook.tsx b/site/src/testHelpers/storybook.tsx index e905a9b412c2c..514d34e0265e8 100644 --- a/site/src/testHelpers/storybook.tsx +++ b/site/src/testHelpers/storybook.tsx @@ -1,6 +1,7 @@ import type { StoryContext } from "@storybook/react"; import { withDefaultFeatures } from "api/api"; import { getAuthorizationKey } from "api/queries/authCheck"; +import { getProvisionerDaemonsKey } from "api/queries/organizations"; import { hasFirstUserKey, meKey } from "api/queries/users"; import type { Entitlements } from "api/typesGenerated"; import { GlobalSnackbar } from "components/GlobalSnackbar/GlobalSnackbar"; @@ -121,6 +122,30 @@ export const withAuthProvider = (Story: FC, { parameters }: StoryContext) => { ); }; +export const withProvisioners = (Story: FC, { parameters }: StoryContext) => { + if (!parameters.organization_id) { + throw new Error( + "You forgot to add `parameters.organization_id` to your story", + ); + } + if (!parameters.provisioners) { + throw new Error( + "You forgot to add `parameters.provisioners` to your story", + ); + } + if (!parameters.tags) { + throw new Error("You forgot to add `parameters.tags` to your story"); + } + + const queryClient = useQueryClient(); + queryClient.setQueryData( + getProvisionerDaemonsKey(parameters.organization_id, parameters.tags), + parameters.provisioners, + ); + + return ; +}; + export const withGlobalSnackbar = (Story: FC) => ( <> From b359fb98ad686ca5b589e12fc566c140fb263471 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Tue, 3 Dec 2024 10:12:30 +0400 Subject: [PATCH 4/5] fix: fix goroutine leak in log streaming over websocket (#15709) fixes #14881 Our handlers for streaming logs don't read from the websocket. We don't allow the client to send us any data, but the websocket library we use requires reading from the websocket to properly handle pings and closing. Not doing so can [can cause the websocket to hang on write](https://github.com/coder/websocket/issues/405), leaking go routines which were noticed in #14881. This fixes the issue, and in process refactors our log streaming to a encoder/decoder package which provides generic types for sending JSON over websocket. I'd also like for us to upgrade to the latest https://github.com/coder/websocket but we should also upgrade our tailscale fork before doing so to avoid including two copies of the websocket library. (cherry picked from commit 148a5a359324aa530cc2b0ba8481f6b47a82b716) --- coderd/provisionerjobs.go | 9 ++-- coderd/workspaceagents.go | 24 ++++------- codersdk/provisionerdaemons.go | 33 ++------------- codersdk/workspaceagents.go | 29 ++----------- codersdk/wsjson/decoder.go | 75 ++++++++++++++++++++++++++++++++++ codersdk/wsjson/encoder.go | 42 +++++++++++++++++++ 6 files changed, 134 insertions(+), 78 deletions(-) create mode 100644 codersdk/wsjson/decoder.go create mode 100644 codersdk/wsjson/encoder.go diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index df832b810e696..3db5d7c20a4bf 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -312,6 +313,7 @@ type logFollower struct { r *http.Request rw http.ResponseWriter conn *websocket.Conn + enc *wsjson.Encoder[codersdk.ProvisionerJobLog] jobID uuid.UUID after int64 @@ -391,6 +393,7 @@ func (f *logFollower) follow() { } defer f.conn.Close(websocket.StatusNormalClosure, "done") go httpapi.Heartbeat(f.ctx, f.conn) + f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText) // query for logs once right away, so we can get historical data from before // subscription @@ -488,11 +491,7 @@ func (f *logFollower) query() error { return xerrors.Errorf("error fetching logs: %w", err) } for _, log := range logs { - logB, err := json.Marshal(convertProvisionerJobLog(log)) - if err != nil { - return xerrors.Errorf("error marshaling log: %w", err) - } - err = f.conn.Write(f.ctx, websocket.MessageText, logB) + err := f.enc.Encode(convertProvisionerJobLog(log)) if err != nil { return xerrors.Errorf("error writing to websocket: %w", err) } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 922d80f0e8085..6bc09e0e770f6 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" ) @@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { } go httpapi.Heartbeat(ctx, conn) - ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) - defer wsNetConn.Close() // Also closes conn. + encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText) + defer encoder.Close(websocket.StatusNormalClosure) - // The Go stdlib JSON encoder appends a newline character after message write. - encoder := json.NewEncoder(wsNetConn) err = encoder.Encode(convertWorkspaceAgentLogs(logs)) if err != nil { return @@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { }) return } - ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary) - defer nconn.Close() - - // Slurp all packets from the connection into io.Discard so pongs get sent - // by the websocket package. We don't do any reads ourselves so this is - // necessary. - go func() { - _, _ = io.Copy(io.Discard, nconn) - _ = nconn.Close() - }() + encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary) + defer encoder.Close(websocket.StatusGoingAway) go func(ctx context.Context) { // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? @@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { err := ws.Ping(ctx) cancel() if err != nil { - _ = nconn.Close() + _ = ws.Close(websocket.StatusGoingAway, "ping failed") return } } @@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { for { derpMap := api.DERPMap() if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) { - err := json.NewEncoder(nconn).Encode(derpMap) + err := encoder.Encode(derpMap) if err != nil { - _ = nconn.Close() return } lastDERPMap = derpMap diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index acd0c6955ab7f..c8bd4354df153 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk/drpc" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" ) @@ -161,36 +162,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after } return nil, nil, ReadBodyAsError(res) } - logs := make(chan ProvisionerJobLog) - closed := make(chan struct{}) - go func() { - defer close(closed) - defer close(logs) - defer conn.Close(websocket.StatusGoingAway, "") - var log ProvisionerJobLog - for { - msgType, msg, err := conn.Read(ctx) - if err != nil { - return - } - if msgType != websocket.MessageText { - return - } - err = json.Unmarshal(msg, &log) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logs <- log: - } - } - }() - return logs, closeFunc(func() error { - <-closed - return nil - }), nil + d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } // ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index eeb335b130cdd..b4aec16a83190 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk/wsjson" ) type WorkspaceAgentStatus string @@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID, } return nil, nil, ReadBodyAsError(res) } - logChunks := make(chan []WorkspaceAgentLog, 1) - closed := make(chan struct{}) - ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText) - decoder := json.NewDecoder(wsNetConn) - go func() { - defer close(closed) - defer close(logChunks) - defer conn.Close(websocket.StatusGoingAway, "") - for { - var logs []WorkspaceAgentLog - err = decoder.Decode(&logs) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logChunks <- logs: - } - } - }() - return logChunks, closeFunc(func() error { - _ = wsNetConn.Close() - <-closed - return nil - }), nil + d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } diff --git a/codersdk/wsjson/decoder.go b/codersdk/wsjson/decoder.go new file mode 100644 index 0000000000000..4cc7ff380a73a --- /dev/null +++ b/codersdk/wsjson/decoder.go @@ -0,0 +1,75 @@ +package wsjson + +import ( + "context" + "encoding/json" + "sync/atomic" + + "nhooyr.io/websocket" + + "cdr.dev/slog" +) + +type Decoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType + ctx context.Context + cancel context.CancelFunc + chanCalled atomic.Bool + logger slog.Logger +} + +// Chan starts the decoder reading from the websocket and returns a channel for reading the +// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an +// error. We also close the underlying websocket if we encounter an error reading or decoding. +func (d *Decoder[T]) Chan() <-chan T { + if !d.chanCalled.CompareAndSwap(false, true) { + panic("chan called more than once") + } + values := make(chan T, 1) + go func() { + defer close(values) + defer d.conn.Close(websocket.StatusGoingAway, "") + for { + // we don't use d.ctx here because it only gets canceled after closing the connection + // and a "connection closed" type error is more clear than context canceled. + typ, b, err := d.conn.Read(context.Background()) + if err != nil { + // might be benign like EOF, so just log at debug + d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err)) + return + } + if typ != d.typ { + d.logger.Error(d.ctx, "websocket type mismatch while decoding") + return + } + var value T + err = json.Unmarshal(b, &value) + if err != nil { + d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err)) + return + } + select { + case values <- value: + // OK + case <-d.ctx.Done(): + return + } + } + }() + return values +} + +// nolint: revive // complains that Encoder has the same function name +func (d *Decoder[T]) Close() error { + err := d.conn.Close(websocket.StatusNormalClosure, "") + d.cancel() + return err +} + +// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from +// JSON. +func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] { + ctx, cancel := context.WithCancel(context.Background()) + return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger} +} diff --git a/codersdk/wsjson/encoder.go b/codersdk/wsjson/encoder.go new file mode 100644 index 0000000000000..4cde05984e690 --- /dev/null +++ b/codersdk/wsjson/encoder.go @@ -0,0 +1,42 @@ +package wsjson + +import ( + "context" + "encoding/json" + + "golang.org/x/xerrors" + "nhooyr.io/websocket" +) + +type Encoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType +} + +func (e *Encoder[T]) Encode(v T) error { + w, err := e.conn.Writer(context.Background(), e.typ) + if err != nil { + return xerrors.Errorf("get websocket writer: %w", err) + } + defer w.Close() + j := json.NewEncoder(w) + err = j.Encode(v) + if err != nil { + return xerrors.Errorf("encode json: %w", err) + } + return nil +} + +func (e *Encoder[T]) Close(c websocket.StatusCode) error { + return e.conn.Close(c, "") +} + +// NewEncoder creates a JSON-over websocket encoder for the type T, which must be JSON-serializable. +// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the +// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects. +func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] { + // Here we close the websocket for reading, so that the websocket library will handle pings and + // close frames. + _ = conn.CloseRead(context.Background()) + return &Encoder[T]{conn: conn, typ: typ} +} From 1d789ce6566b3188f812c517ca234ed7df947f6f Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Tue, 3 Dec 2024 19:26:31 +0000 Subject: [PATCH 5/5] fix(coderd): ensure that clearing invalid oauth refresh tokens works with dbcrypt (#15721) https://github.com/coder/coder/pull/15608 introduced a buggy behaviour with dbcrypt enabled. When clearing an oauth refresh token, we had been setting the value to the empty string. The database encryption package considers decrypting an empty string to be an error, as an empty encrypted string value will still have a nonce associated with it and thus not actually be empty when stored at rest. Instead of 'deleting' the refresh token, 'update' it to be the empty string. This plays nicely with dbcrypt. It also adds a 'utility test' in the dbcrypt package to help encrypt a value. This was useful when manually fixing users affected by this bug on our dogfood instance. (cherry picked from commit e744cde86f6407e10bb03794d837e498d5df6f3d) --- coderd/database/dbauthz/dbauthz.go | 14 ++--- coderd/database/dbauthz/dbauthz_test.go | 12 +++-- coderd/database/dbmem/dbmem.go | 46 ++++++++--------- coderd/database/dbmetrics/querymetrics.go | 14 ++--- coderd/database/dbmock/dbmock.go | 28 +++++----- coderd/database/querier.go | 5 +- coderd/database/queries.sql.go | 57 ++++++++++++--------- coderd/database/queries/externalauth.sql | 15 +++--- coderd/externalauth/externalauth.go | 10 ++-- coderd/externalauth/externalauth_test.go | 2 +- enterprise/dbcrypt/cipher_internal_test.go | 34 ++++++++++++ enterprise/dbcrypt/dbcrypt.go | 15 ++++++ enterprise/dbcrypt/dbcrypt_internal_test.go | 26 ++++++++++ 13 files changed, 184 insertions(+), 94 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 58c9179da5e4b..c8e8880b79fed 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3330,13 +3330,6 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } -func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { - fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) { - return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) - } - return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg) -} - func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { // This is a system function to clear user groups in group sync. if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { @@ -3435,6 +3428,13 @@ func (q *querier) UpdateExternalAuthLink(ctx context.Context, arg database.Updat return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLink)(ctx, arg) } +func (q *querier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { + fetch := func(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) (database.ExternalAuthLink, error) { + return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) + } + return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLinkRefreshToken)(ctx, arg) +} + func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { return q.db.GetGitSSHKey(ctx, arg.UserID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 638829ae24ae5..1c60018e87062 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1282,12 +1282,14 @@ func (s *MethodTestSuite) TestUser() { UserID: u.ID, }).Asserts(u, policy.ActionUpdatePersonal) })) - s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpdateExternalAuthLinkRefreshToken", s.Subtest(func(db database.Store, check *expects) { link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) - check.Args(database.RemoveRefreshTokenParams{ - ProviderID: link.ProviderID, - UserID: link.UserID, - UpdatedAt: link.UpdatedAt, + check.Args(database.UpdateExternalAuthLinkRefreshTokenParams{ + OAuthRefreshToken: "", + OAuthRefreshTokenKeyID: "", + ProviderID: link.ProviderID, + UserID: link.UserID, + UpdatedAt: link.UpdatedAt, }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal) })) s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 765573b311a84..385cdcfde5709 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -8556,29 +8556,6 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } -func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.externalAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthRefreshToken = "" - q.externalAuthLinks[index] = gitAuthLink - - return nil - } - return sql.ErrNoRows -} - func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -8798,6 +8775,29 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd return database.ExternalAuthLink{}, sql.ErrNoRows } +func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(_ context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + for index, gitAuthLink := range q.externalAuthLinks { + if gitAuthLink.ProviderID != arg.ProviderID { + continue + } + if gitAuthLink.UserID != arg.UserID { + continue + } + gitAuthLink.UpdatedAt = arg.UpdatedAt + gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken + q.externalAuthLinks[index] = gitAuthLink + + return nil + } + return sql.ErrNoRows +} + func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { if err := validateDatabaseType(arg); err != nil { return database.GitSSHKey{}, err diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index efde94488828f..54dd723ae1395 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2093,13 +2093,6 @@ func (m queryMetricsStore) RegisterWorkspaceProxy(ctx context.Context, arg datab return proxy, err } -func (m queryMetricsStore) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { - start := time.Now() - r0 := m.s.RemoveRefreshToken(ctx, arg) - m.queryLatencies.WithLabelValues("RemoveRefreshToken").Observe(time.Since(start).Seconds()) - return r0 -} - func (m queryMetricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { start := time.Now() r0 := m.s.RemoveUserFromAllGroups(ctx, userID) @@ -2170,6 +2163,13 @@ func (m queryMetricsStore) UpdateExternalAuthLink(ctx context.Context, arg datab return link, err } +func (m queryMetricsStore) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { + start := time.Now() + r0 := m.s.UpdateExternalAuthLinkRefreshToken(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateExternalAuthLinkRefreshToken").Observe(time.Since(start).Seconds()) + return r0 +} + func (m queryMetricsStore) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { start := time.Now() key, err := m.s.UpdateGitSSHKey(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index eefa89c86b57f..064d0dfd926c8 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4463,20 +4463,6 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } -// RemoveRefreshToken mocks base method. -func (m *MockStore) RemoveRefreshToken(arg0 context.Context, arg1 database.RemoveRefreshTokenParams) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveRefreshToken", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveRefreshToken indicates an expected call of RemoveRefreshToken. -func (mr *MockStoreMockRecorder) RemoveRefreshToken(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveRefreshToken", reflect.TypeOf((*MockStore)(nil).RemoveRefreshToken), arg0, arg1) -} - // RemoveUserFromAllGroups mocks base method. func (m *MockStore) RemoveUserFromAllGroups(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -4622,6 +4608,20 @@ func (mr *MockStoreMockRecorder) UpdateExternalAuthLink(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExternalAuthLink", reflect.TypeOf((*MockStore)(nil).UpdateExternalAuthLink), arg0, arg1) } +// UpdateExternalAuthLinkRefreshToken mocks base method. +func (m *MockStore) UpdateExternalAuthLinkRefreshToken(arg0 context.Context, arg1 database.UpdateExternalAuthLinkRefreshTokenParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateExternalAuthLinkRefreshToken", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateExternalAuthLinkRefreshToken indicates an expected call of UpdateExternalAuthLinkRefreshToken. +func (mr *MockStoreMockRecorder) UpdateExternalAuthLinkRefreshToken(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExternalAuthLinkRefreshToken", reflect.TypeOf((*MockStore)(nil).UpdateExternalAuthLinkRefreshToken), arg0, arg1) +} + // UpdateGitSSHKey mocks base method. func (m *MockStore) UpdateGitSSHKey(arg0 context.Context, arg1 database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index d75b051cac330..07b8056e1a5c4 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -424,10 +424,6 @@ type sqlcQuerier interface { OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) - // Removing the refresh token disables the refresh behavior for a given - // auth token. If a refresh token is marked invalid, it is better to remove it - // then continually attempt to refresh the token. - RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error @@ -443,6 +439,7 @@ type sqlcQuerier interface { UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) + UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 33a3ce12a444d..e9fe766f31e53 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1194,29 +1194,6 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter return i, err } -const removeRefreshToken = `-- name: RemoveRefreshToken :exec -UPDATE - external_auth_links -SET - oauth_refresh_token = '', - updated_at = $1 -WHERE provider_id = $2 AND user_id = $3 -` - -type RemoveRefreshTokenParams struct { - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - ProviderID string `db:"provider_id" json:"provider_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` -} - -// Removing the refresh token disables the refresh behavior for a given -// auth token. If a refresh token is marked invalid, it is better to remove it -// then continually attempt to refresh the token. -func (q *sqlQuerier) RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error { - _, err := q.db.ExecContext(ctx, removeRefreshToken, arg.UpdatedAt, arg.ProviderID, arg.UserID) - return err -} - const updateExternalAuthLink = `-- name: UpdateExternalAuthLink :one UPDATE external_auth_links SET updated_at = $3, @@ -1269,6 +1246,40 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter return i, err } +const updateExternalAuthLinkRefreshToken = `-- name: UpdateExternalAuthLinkRefreshToken :exec +UPDATE + external_auth_links +SET + oauth_refresh_token = $1, + updated_at = $2 +WHERE + provider_id = $3 +AND + user_id = $4 +AND + -- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id + $5 :: text = $5 :: text +` + +type UpdateExternalAuthLinkRefreshTokenParams struct { + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + OAuthRefreshTokenKeyID string `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` +} + +func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error { + _, err := q.db.ExecContext(ctx, updateExternalAuthLinkRefreshToken, + arg.OAuthRefreshToken, + arg.UpdatedAt, + arg.ProviderID, + arg.UserID, + arg.OAuthRefreshTokenKeyID, + ) + return err +} + const getFileByHashAndCreator = `-- name: GetFileByHashAndCreator :one SELECT hash, created_at, created_by, mimetype, data, id diff --git a/coderd/database/queries/externalauth.sql b/coderd/database/queries/externalauth.sql index cd223bd792a2a..4368ce56589f0 100644 --- a/coderd/database/queries/externalauth.sql +++ b/coderd/database/queries/externalauth.sql @@ -43,13 +43,16 @@ UPDATE external_auth_links SET oauth_extra = $9 WHERE provider_id = $1 AND user_id = $2 RETURNING *; --- name: RemoveRefreshToken :exec --- Removing the refresh token disables the refresh behavior for a given --- auth token. If a refresh token is marked invalid, it is better to remove it --- then continually attempt to refresh the token. +-- name: UpdateExternalAuthLinkRefreshToken :exec UPDATE external_auth_links SET - oauth_refresh_token = '', + oauth_refresh_token = @oauth_refresh_token, updated_at = @updated_at -WHERE provider_id = @provider_id AND user_id = @user_id; +WHERE + provider_id = @provider_id +AND + user_id = @user_id +AND + -- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id + @oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text; diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 1ce850c9cec03..95ee751ca674e 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -143,10 +143,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu // get rid of it. Keeping it around will cause additional refresh // attempts that will fail and cost us api rate limits. if isFailedRefresh(existingToken, err) { - dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{ - UpdatedAt: dbtime.Now(), - ProviderID: externalAuthLink.ProviderID, - UserID: externalAuthLink.UserID, + dbExecErr := db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{ + OAuthRefreshToken: "", // It is better to clear the refresh token than to keep retrying. + OAuthRefreshTokenKeyID: externalAuthLink.OAuthRefreshTokenKeyID.String, + UpdatedAt: dbtime.Now(), + ProviderID: externalAuthLink.ProviderID, + UserID: externalAuthLink.UserID, }) if dbExecErr != nil { // This error should be rare. diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 84bded9856572..d3ba2262962b6 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -190,7 +190,7 @@ func TestRefreshToken(t *testing.T) { // Try again with a bad refresh token error // Expect DB call to remove the refresh token - mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1) + mDB.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1) refreshErr = &oauth2.RetrieveError{ // github error Response: &http.Response{ StatusCode: http.StatusOK, diff --git a/enterprise/dbcrypt/cipher_internal_test.go b/enterprise/dbcrypt/cipher_internal_test.go index b6740de17eec6..c70796ba27e97 100644 --- a/enterprise/dbcrypt/cipher_internal_test.go +++ b/enterprise/dbcrypt/cipher_internal_test.go @@ -3,6 +3,8 @@ package dbcrypt import ( "bytes" "encoding/base64" + "os" + "strings" "testing" "github.com/stretchr/testify/require" @@ -89,3 +91,35 @@ func TestCiphersBackwardCompatibility(t *testing.T) { require.NoError(t, err, "decryption should succeed") require.Equal(t, msg, string(decrypted), "decrypted message should match original message") } + +// If you're looking here, you're probably in trouble. +// Here's what you need to do: +// 1. Get the current CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable. +// 2. Run the following command: +// ENCRYPT_ME="" CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS="" go test -v -count=1 ./enterprise/dbcrypt -test.run='^TestHelpMeEncryptSomeValue$' +// 3. Copy the value from the test output and do what you need with it. +func TestHelpMeEncryptSomeValue(t *testing.T) { + t.Parallel() + t.Skip("this only exists if you need to encrypt a value with dbcrypt, it does not actually test anything") + + valueToEncrypt := os.Getenv("ENCRYPT_ME") + t.Logf("valueToEncrypt: %q", valueToEncrypt) + keys := os.Getenv("CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS") + require.NotEmpty(t, keys, "Set the CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable to use this") + + base64Keys := strings.Split(keys, ",") + activeKey := base64Keys[0] + + decodedKey, err := base64.StdEncoding.DecodeString(activeKey) + require.NoError(t, err, "the active key should be valid base64") + + cipher, err := cipherAES256(decodedKey) + require.NoError(t, err) + + t.Logf("cipher digest: %+v", cipher.HexDigest()) + + encryptedEmptyString, err := cipher.Encrypt([]byte(valueToEncrypt)) + require.NoError(t, err) + + t.Logf("encrypted and base64-encoded: %q", base64.StdEncoding.EncodeToString(encryptedEmptyString)) +} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 77a7d5cb78738..e0ca58cc5231a 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -261,6 +261,21 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U return link, nil } +func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error { + // We would normally use a sql.NullString here, but sqlc does not want to make + // a params struct with a nullable string. + var digest sql.NullString + if params.OAuthRefreshTokenKeyID != "" { + digest.String = params.OAuthRefreshTokenKeyID + digest.Valid = true + } + if err := db.encryptField(¶ms.OAuthRefreshToken, &digest); err != nil { + return err + } + + return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params) +} + func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { keys, err := db.Store.GetCryptoKeys(ctx) if err != nil { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 3e252496d6a69..e73c3eee85c16 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -17,6 +17,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" ) func TestUserLinks(t *testing.T) { @@ -96,6 +97,31 @@ func TestUserLinks(t *testing.T) { require.EqualValues(t, expectedClaims, rawLink.Claims) }) + t.Run("UpdateExternalAuthLinkRefreshToken", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{ + UserID: user.ID, + }) + + err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{ + OAuthRefreshToken: "", + OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String, + UpdatedAt: dbtime.Now(), + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + + rawLink, err := db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "") + }) + t.Run("GetUserLinkByLinkedID", func(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) {