diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 58c9179da5e4b..c8e8880b79fed 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3330,13 +3330,6 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg) } -func (q *querier) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { - fetch := func(ctx context.Context, arg database.RemoveRefreshTokenParams) (database.ExternalAuthLink, error) { - return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) - } - return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.RemoveRefreshToken)(ctx, arg) -} - func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { // This is a system function to clear user groups in group sync. if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { @@ -3435,6 +3428,13 @@ func (q *querier) UpdateExternalAuthLink(ctx context.Context, arg database.Updat return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLink)(ctx, arg) } +func (q *querier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { + fetch := func(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) (database.ExternalAuthLink, error) { + return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) + } + return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateExternalAuthLinkRefreshToken)(ctx, arg) +} + func (q *querier) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { fetch := func(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { return q.db.GetGitSSHKey(ctx, arg.UserID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 638829ae24ae5..1c60018e87062 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1282,12 +1282,14 @@ func (s *MethodTestSuite) TestUser() { UserID: u.ID, }).Asserts(u, policy.ActionUpdatePersonal) })) - s.Run("RemoveRefreshToken", s.Subtest(func(db database.Store, check *expects) { + s.Run("UpdateExternalAuthLinkRefreshToken", s.Subtest(func(db database.Store, check *expects) { link := dbgen.ExternalAuthLink(s.T(), db, database.ExternalAuthLink{}) - check.Args(database.RemoveRefreshTokenParams{ - ProviderID: link.ProviderID, - UserID: link.UserID, - UpdatedAt: link.UpdatedAt, + check.Args(database.UpdateExternalAuthLinkRefreshTokenParams{ + OAuthRefreshToken: "", + OAuthRefreshTokenKeyID: "", + ProviderID: link.ProviderID, + UserID: link.UserID, + UpdatedAt: link.UpdatedAt, }).Asserts(rbac.ResourceUserObject(link.UserID), policy.ActionUpdatePersonal) })) s.Run("UpdateExternalAuthLink", s.Subtest(func(db database.Store, check *expects) { diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index ae898d4f1fdc3..9c8696112dea8 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -788,16 +788,17 @@ func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVers err := db.InTx(func(db database.Store) error { versionID := takeFirst(orig.ID, uuid.New()) err := db.InsertTemplateVersion(genCtx, database.InsertTemplateVersionParams{ - ID: versionID, - TemplateID: takeFirst(orig.TemplateID, uuid.NullUUID{}), - OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), - CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), - UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), - Name: takeFirst(orig.Name, testutil.GetRandomName(t)), - Message: orig.Message, - Readme: takeFirst(orig.Readme, testutil.GetRandomName(t)), - JobID: takeFirst(orig.JobID, uuid.New()), - CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + ID: versionID, + TemplateID: takeFirst(orig.TemplateID, uuid.NullUUID{}), + OrganizationID: takeFirst(orig.OrganizationID, uuid.New()), + CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()), + UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()), + Name: takeFirst(orig.Name, testutil.GetRandomName(t)), + Message: orig.Message, + Readme: takeFirst(orig.Readme, testutil.GetRandomName(t)), + JobID: takeFirst(orig.JobID, uuid.New()), + CreatedBy: takeFirst(orig.CreatedBy, uuid.New()), + SourceExampleID: takeFirst(orig.SourceExampleID, sql.NullString{}), }) if err != nil { return err diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 5583fff111648..385cdcfde5709 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -7699,16 +7699,17 @@ func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.Inse //nolint:gosimple version := database.TemplateVersionTable{ - ID: arg.ID, - TemplateID: arg.TemplateID, - OrganizationID: arg.OrganizationID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Message: arg.Message, - Readme: arg.Readme, - JobID: arg.JobID, - CreatedBy: arg.CreatedBy, + ID: arg.ID, + TemplateID: arg.TemplateID, + OrganizationID: arg.OrganizationID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Name: arg.Name, + Message: arg.Message, + Readme: arg.Readme, + JobID: arg.JobID, + CreatedBy: arg.CreatedBy, + SourceExampleID: arg.SourceExampleID, } q.templateVersions = append(q.templateVersions, version) return nil @@ -8555,29 +8556,6 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg return database.WorkspaceProxy{}, sql.ErrNoRows } -func (q *FakeQuerier) RemoveRefreshToken(_ context.Context, arg database.RemoveRefreshTokenParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - for index, gitAuthLink := range q.externalAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { - continue - } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthRefreshToken = "" - q.externalAuthLinks[index] = gitAuthLink - - return nil - } - return sql.ErrNoRows -} - func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -8797,6 +8775,29 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd return database.ExternalAuthLink{}, sql.ErrNoRows } +func (q *FakeQuerier) UpdateExternalAuthLinkRefreshToken(_ context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + for index, gitAuthLink := range q.externalAuthLinks { + if gitAuthLink.ProviderID != arg.ProviderID { + continue + } + if gitAuthLink.UserID != arg.UserID { + continue + } + gitAuthLink.UpdatedAt = arg.UpdatedAt + gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken + q.externalAuthLinks[index] = gitAuthLink + + return nil + } + return sql.ErrNoRows +} + func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { if err := validateDatabaseType(arg); err != nil { return database.GitSSHKey{}, err diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index efde94488828f..54dd723ae1395 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2093,13 +2093,6 @@ func (m queryMetricsStore) RegisterWorkspaceProxy(ctx context.Context, arg datab return proxy, err } -func (m queryMetricsStore) RemoveRefreshToken(ctx context.Context, arg database.RemoveRefreshTokenParams) error { - start := time.Now() - r0 := m.s.RemoveRefreshToken(ctx, arg) - m.queryLatencies.WithLabelValues("RemoveRefreshToken").Observe(time.Since(start).Seconds()) - return r0 -} - func (m queryMetricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error { start := time.Now() r0 := m.s.RemoveUserFromAllGroups(ctx, userID) @@ -2170,6 +2163,13 @@ func (m queryMetricsStore) UpdateExternalAuthLink(ctx context.Context, arg datab return link, err } +func (m queryMetricsStore) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg database.UpdateExternalAuthLinkRefreshTokenParams) error { + start := time.Now() + r0 := m.s.UpdateExternalAuthLinkRefreshToken(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateExternalAuthLinkRefreshToken").Observe(time.Since(start).Seconds()) + return r0 +} + func (m queryMetricsStore) UpdateGitSSHKey(ctx context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { start := time.Now() key, err := m.s.UpdateGitSSHKey(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index eefa89c86b57f..064d0dfd926c8 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4463,20 +4463,6 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), arg0, arg1) } -// RemoveRefreshToken mocks base method. -func (m *MockStore) RemoveRefreshToken(arg0 context.Context, arg1 database.RemoveRefreshTokenParams) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveRefreshToken", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveRefreshToken indicates an expected call of RemoveRefreshToken. -func (mr *MockStoreMockRecorder) RemoveRefreshToken(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveRefreshToken", reflect.TypeOf((*MockStore)(nil).RemoveRefreshToken), arg0, arg1) -} - // RemoveUserFromAllGroups mocks base method. func (m *MockStore) RemoveUserFromAllGroups(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -4622,6 +4608,20 @@ func (mr *MockStoreMockRecorder) UpdateExternalAuthLink(arg0, arg1 any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExternalAuthLink", reflect.TypeOf((*MockStore)(nil).UpdateExternalAuthLink), arg0, arg1) } +// UpdateExternalAuthLinkRefreshToken mocks base method. +func (m *MockStore) UpdateExternalAuthLinkRefreshToken(arg0 context.Context, arg1 database.UpdateExternalAuthLinkRefreshTokenParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateExternalAuthLinkRefreshToken", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateExternalAuthLinkRefreshToken indicates an expected call of UpdateExternalAuthLinkRefreshToken. +func (mr *MockStoreMockRecorder) UpdateExternalAuthLinkRefreshToken(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExternalAuthLinkRefreshToken", reflect.TypeOf((*MockStore)(nil).UpdateExternalAuthLinkRefreshToken), arg0, arg1) +} + // UpdateGitSSHKey mocks base method. func (m *MockStore) UpdateGitSSHKey(arg0 context.Context, arg1 database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 9919011579bde..eba9b7cf106d3 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1217,7 +1217,8 @@ CREATE TABLE template_versions ( created_by uuid NOT NULL, external_auth_providers jsonb DEFAULT '[]'::jsonb NOT NULL, message character varying(1048576) DEFAULT ''::character varying NOT NULL, - archived boolean DEFAULT false NOT NULL + archived boolean DEFAULT false NOT NULL, + source_example_id text ); COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External auth providers for a specific template version'; @@ -1245,6 +1246,7 @@ CREATE VIEW template_version_with_user AS template_versions.external_auth_providers, template_versions.message, template_versions.archived, + template_versions.source_example_id, COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url, COALESCE(visible_users.username, ''::text) AS created_by_username FROM (template_versions diff --git a/coderd/database/migrations/000277_template_version_example_ids.down.sql b/coderd/database/migrations/000277_template_version_example_ids.down.sql new file mode 100644 index 0000000000000..ad961e9f635c7 --- /dev/null +++ b/coderd/database/migrations/000277_template_version_example_ids.down.sql @@ -0,0 +1,28 @@ +-- We cannot alter the column type while a view depends on it, so we drop it and recreate it. +DROP VIEW template_version_with_user; + +ALTER TABLE + template_versions +DROP COLUMN source_example_id; + +-- Recreate `template_version_with_user` as described in dump.sql +CREATE VIEW template_version_with_user AS +SELECT + template_versions.id, + template_versions.template_id, + template_versions.organization_id, + template_versions.created_at, + template_versions.updated_at, + template_versions.name, + template_versions.readme, + template_versions.job_id, + template_versions.created_by, + template_versions.external_auth_providers, + template_versions.message, + template_versions.archived, + COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url, + COALESCE(visible_users.username, ''::text) AS created_by_username +FROM (template_versions + LEFT JOIN visible_users ON (template_versions.created_by = visible_users.id)); + +COMMENT ON VIEW template_version_with_user IS 'Joins in the username + avatar url of the created by user.'; diff --git a/coderd/database/migrations/000277_template_version_example_ids.up.sql b/coderd/database/migrations/000277_template_version_example_ids.up.sql new file mode 100644 index 0000000000000..aca34b31de5dc --- /dev/null +++ b/coderd/database/migrations/000277_template_version_example_ids.up.sql @@ -0,0 +1,30 @@ +-- We cannot alter the column type while a view depends on it, so we drop it and recreate it. +DROP VIEW template_version_with_user; + +ALTER TABLE + template_versions +ADD + COLUMN source_example_id TEXT; + +-- Recreate `template_version_with_user` as described in dump.sql +CREATE VIEW template_version_with_user AS +SELECT + template_versions.id, + template_versions.template_id, + template_versions.organization_id, + template_versions.created_at, + template_versions.updated_at, + template_versions.name, + template_versions.readme, + template_versions.job_id, + template_versions.created_by, + template_versions.external_auth_providers, + template_versions.message, + template_versions.archived, + template_versions.source_example_id, + COALESCE(visible_users.avatar_url, ''::text) AS created_by_avatar_url, + COALESCE(visible_users.username, ''::text) AS created_by_username +FROM (template_versions + LEFT JOIN visible_users ON (template_versions.created_by = visible_users.id)); + +COMMENT ON VIEW template_version_with_user IS 'Joins in the username + avatar url of the created by user.'; diff --git a/coderd/database/models.go b/coderd/database/models.go index af0a3122f7964..6b99245079950 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -2773,6 +2773,7 @@ type TemplateVersion struct { ExternalAuthProviders json.RawMessage `db:"external_auth_providers" json:"external_auth_providers"` Message string `db:"message" json:"message"` Archived bool `db:"archived" json:"archived"` + SourceExampleID sql.NullString `db:"source_example_id" json:"source_example_id"` CreatedByAvatarURL string `db:"created_by_avatar_url" json:"created_by_avatar_url"` CreatedByUsername string `db:"created_by_username" json:"created_by_username"` } @@ -2826,8 +2827,9 @@ type TemplateVersionTable struct { // IDs of External auth providers for a specific template version ExternalAuthProviders json.RawMessage `db:"external_auth_providers" json:"external_auth_providers"` // Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact. - Message string `db:"message" json:"message"` - Archived bool `db:"archived" json:"archived"` + Message string `db:"message" json:"message"` + Archived bool `db:"archived" json:"archived"` + SourceExampleID sql.NullString `db:"source_example_id" json:"source_example_id"` } type TemplateVersionVariable struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index d75b051cac330..07b8056e1a5c4 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -424,10 +424,6 @@ type sqlcQuerier interface { OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) - // Removing the refresh token disables the refresh behavior for a given - // auth token. If a refresh token is marked invalid, it is better to remove it - // then continually attempt to refresh the token. - RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error @@ -443,6 +439,7 @@ type sqlcQuerier interface { UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) + UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) (GitSSHKey, error) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) UpdateInactiveUsersToDormant(ctx context.Context, arg UpdateInactiveUsersToDormantParams) ([]UpdateInactiveUsersToDormantRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 4eec78cf97fba..e9fe766f31e53 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1194,29 +1194,6 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter return i, err } -const removeRefreshToken = `-- name: RemoveRefreshToken :exec -UPDATE - external_auth_links -SET - oauth_refresh_token = '', - updated_at = $1 -WHERE provider_id = $2 AND user_id = $3 -` - -type RemoveRefreshTokenParams struct { - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - ProviderID string `db:"provider_id" json:"provider_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` -} - -// Removing the refresh token disables the refresh behavior for a given -// auth token. If a refresh token is marked invalid, it is better to remove it -// then continually attempt to refresh the token. -func (q *sqlQuerier) RemoveRefreshToken(ctx context.Context, arg RemoveRefreshTokenParams) error { - _, err := q.db.ExecContext(ctx, removeRefreshToken, arg.UpdatedAt, arg.ProviderID, arg.UserID) - return err -} - const updateExternalAuthLink = `-- name: UpdateExternalAuthLink :one UPDATE external_auth_links SET updated_at = $3, @@ -1269,6 +1246,40 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter return i, err } +const updateExternalAuthLinkRefreshToken = `-- name: UpdateExternalAuthLinkRefreshToken :exec +UPDATE + external_auth_links +SET + oauth_refresh_token = $1, + updated_at = $2 +WHERE + provider_id = $3 +AND + user_id = $4 +AND + -- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id + $5 :: text = $5 :: text +` + +type UpdateExternalAuthLinkRefreshTokenParams struct { + OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ProviderID string `db:"provider_id" json:"provider_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + OAuthRefreshTokenKeyID string `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"` +} + +func (q *sqlQuerier) UpdateExternalAuthLinkRefreshToken(ctx context.Context, arg UpdateExternalAuthLinkRefreshTokenParams) error { + _, err := q.db.ExecContext(ctx, updateExternalAuthLinkRefreshToken, + arg.OAuthRefreshToken, + arg.UpdatedAt, + arg.ProviderID, + arg.UserID, + arg.OAuthRefreshTokenKeyID, + ) + return err +} + const getFileByHashAndCreator = `-- name: GetFileByHashAndCreator :one SELECT hash, created_at, created_by, mimetype, data, id @@ -8996,7 +9007,7 @@ FROM -- Scope an archive to a single template and ignore already archived template versions ( SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id FROM template_versions WHERE @@ -9097,7 +9108,7 @@ func (q *sqlQuerier) ArchiveUnusedTemplateVersions(ctx context.Context, arg Arch const getPreviousTemplateVersion = `-- name: GetPreviousTemplateVersion :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9134,6 +9145,7 @@ func (q *sqlQuerier) GetPreviousTemplateVersion(ctx context.Context, arg GetPrev &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9142,7 +9154,7 @@ func (q *sqlQuerier) GetPreviousTemplateVersion(ctx context.Context, arg GetPrev const getTemplateVersionByID = `-- name: GetTemplateVersionByID :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9165,6 +9177,7 @@ func (q *sqlQuerier) GetTemplateVersionByID(ctx context.Context, id uuid.UUID) ( &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9173,7 +9186,7 @@ func (q *sqlQuerier) GetTemplateVersionByID(ctx context.Context, id uuid.UUID) ( const getTemplateVersionByJobID = `-- name: GetTemplateVersionByJobID :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9196,6 +9209,7 @@ func (q *sqlQuerier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.U &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9204,7 +9218,7 @@ func (q *sqlQuerier) GetTemplateVersionByJobID(ctx context.Context, jobID uuid.U const getTemplateVersionByTemplateIDAndName = `-- name: GetTemplateVersionByTemplateIDAndName :one SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9233,6 +9247,7 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ) @@ -9241,7 +9256,7 @@ func (q *sqlQuerier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, const getTemplateVersionsByIDs = `-- name: GetTemplateVersionsByIDs :many SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9270,6 +9285,7 @@ func (q *sqlQuerier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UU &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ); err != nil { @@ -9288,7 +9304,7 @@ func (q *sqlQuerier) GetTemplateVersionsByIDs(ctx context.Context, ids []uuid.UU const getTemplateVersionsByTemplateID = `-- name: GetTemplateVersionsByTemplateID :many SELECT - id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username + id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE @@ -9364,6 +9380,7 @@ func (q *sqlQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg Ge &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ); err != nil { @@ -9381,7 +9398,7 @@ func (q *sqlQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg Ge } const getTemplateVersionsCreatedAfter = `-- name: GetTemplateVersionsCreatedAfter :many -SELECT id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE created_at > $1 +SELECT id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id, created_by_avatar_url, created_by_username FROM template_version_with_user AS template_versions WHERE created_at > $1 ` func (q *sqlQuerier) GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error) { @@ -9406,6 +9423,7 @@ func (q *sqlQuerier) GetTemplateVersionsCreatedAfter(ctx context.Context, create &i.ExternalAuthProviders, &i.Message, &i.Archived, + &i.SourceExampleID, &i.CreatedByAvatarURL, &i.CreatedByUsername, ); err != nil { @@ -9434,23 +9452,25 @@ INSERT INTO message, readme, job_id, - created_by + created_by, + source_example_id ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ` type InsertTemplateVersionParams struct { - ID uuid.UUID `db:"id" json:"id"` - TemplateID uuid.NullUUID `db:"template_id" json:"template_id"` - OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - Name string `db:"name" json:"name"` - Message string `db:"message" json:"message"` - Readme string `db:"readme" json:"readme"` - JobID uuid.UUID `db:"job_id" json:"job_id"` - CreatedBy uuid.UUID `db:"created_by" json:"created_by"` + ID uuid.UUID `db:"id" json:"id"` + TemplateID uuid.NullUUID `db:"template_id" json:"template_id"` + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Name string `db:"name" json:"name"` + Message string `db:"message" json:"message"` + Readme string `db:"readme" json:"readme"` + JobID uuid.UUID `db:"job_id" json:"job_id"` + CreatedBy uuid.UUID `db:"created_by" json:"created_by"` + SourceExampleID sql.NullString `db:"source_example_id" json:"source_example_id"` } func (q *sqlQuerier) InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) error { @@ -9465,6 +9485,7 @@ func (q *sqlQuerier) InsertTemplateVersion(ctx context.Context, arg InsertTempla arg.Readme, arg.JobID, arg.CreatedBy, + arg.SourceExampleID, ) return err } diff --git a/coderd/database/queries/externalauth.sql b/coderd/database/queries/externalauth.sql index cd223bd792a2a..4368ce56589f0 100644 --- a/coderd/database/queries/externalauth.sql +++ b/coderd/database/queries/externalauth.sql @@ -43,13 +43,16 @@ UPDATE external_auth_links SET oauth_extra = $9 WHERE provider_id = $1 AND user_id = $2 RETURNING *; --- name: RemoveRefreshToken :exec --- Removing the refresh token disables the refresh behavior for a given --- auth token. If a refresh token is marked invalid, it is better to remove it --- then continually attempt to refresh the token. +-- name: UpdateExternalAuthLinkRefreshToken :exec UPDATE external_auth_links SET - oauth_refresh_token = '', + oauth_refresh_token = @oauth_refresh_token, updated_at = @updated_at -WHERE provider_id = @provider_id AND user_id = @user_id; +WHERE + provider_id = @provider_id +AND + user_id = @user_id +AND + -- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id + @oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text; diff --git a/coderd/database/queries/templateversions.sql b/coderd/database/queries/templateversions.sql index 094c1b6014de7..0436a7f9ba3b9 100644 --- a/coderd/database/queries/templateversions.sql +++ b/coderd/database/queries/templateversions.sql @@ -87,10 +87,11 @@ INSERT INTO message, readme, job_id, - created_by + created_by, + source_example_id ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10); + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); -- name: UpdateTemplateVersionByID :exec UPDATE diff --git a/coderd/externalauth/externalauth.go b/coderd/externalauth/externalauth.go index 1ce850c9cec03..95ee751ca674e 100644 --- a/coderd/externalauth/externalauth.go +++ b/coderd/externalauth/externalauth.go @@ -143,10 +143,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu // get rid of it. Keeping it around will cause additional refresh // attempts that will fail and cost us api rate limits. if isFailedRefresh(existingToken, err) { - dbExecErr := db.RemoveRefreshToken(ctx, database.RemoveRefreshTokenParams{ - UpdatedAt: dbtime.Now(), - ProviderID: externalAuthLink.ProviderID, - UserID: externalAuthLink.UserID, + dbExecErr := db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{ + OAuthRefreshToken: "", // It is better to clear the refresh token than to keep retrying. + OAuthRefreshTokenKeyID: externalAuthLink.OAuthRefreshTokenKeyID.String, + UpdatedAt: dbtime.Now(), + ProviderID: externalAuthLink.ProviderID, + UserID: externalAuthLink.UserID, }) if dbExecErr != nil { // This error should be rare. diff --git a/coderd/externalauth/externalauth_test.go b/coderd/externalauth/externalauth_test.go index 84bded9856572..d3ba2262962b6 100644 --- a/coderd/externalauth/externalauth_test.go +++ b/coderd/externalauth/externalauth_test.go @@ -190,7 +190,7 @@ func TestRefreshToken(t *testing.T) { // Try again with a bad refresh token error // Expect DB call to remove the refresh token - mDB.EXPECT().RemoveRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1) + mDB.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), gomock.Any()).Return(nil).Times(1) refreshErr = &oauth2.RetrieveError{ // github error Response: &http.Response{ StatusCode: http.StatusOK, diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index df832b810e696..3db5d7c20a4bf 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -312,6 +313,7 @@ type logFollower struct { r *http.Request rw http.ResponseWriter conn *websocket.Conn + enc *wsjson.Encoder[codersdk.ProvisionerJobLog] jobID uuid.UUID after int64 @@ -391,6 +393,7 @@ func (f *logFollower) follow() { } defer f.conn.Close(websocket.StatusNormalClosure, "done") go httpapi.Heartbeat(f.ctx, f.conn) + f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText) // query for logs once right away, so we can get historical data from before // subscription @@ -488,11 +491,7 @@ func (f *logFollower) query() error { return xerrors.Errorf("error fetching logs: %w", err) } for _, log := range logs { - logB, err := json.Marshal(convertProvisionerJobLog(log)) - if err != nil { - return xerrors.Errorf("error marshaling log: %w", err) - } - err = f.conn.Write(f.ctx, websocket.MessageText, logB) + err := f.enc.Encode(convertProvisionerJobLog(log)) if err != nil { return xerrors.Errorf("error writing to websocket: %w", err) } diff --git a/coderd/telemetry/telemetry.go b/coderd/telemetry/telemetry.go index 8ad85b0b39982..233450c43d943 100644 --- a/coderd/telemetry/telemetry.go +++ b/coderd/telemetry/telemetry.go @@ -868,6 +868,9 @@ func ConvertTemplateVersion(version database.TemplateVersion) TemplateVersion { if version.TemplateID.Valid { snapVersion.TemplateID = &version.TemplateID.UUID } + if version.SourceExampleID.Valid { + snapVersion.SourceExampleID = &version.SourceExampleID.String + } return snapVersion } @@ -1116,11 +1119,12 @@ type Template struct { } type TemplateVersion struct { - ID uuid.UUID `json:"id"` - CreatedAt time.Time `json:"created_at"` - TemplateID *uuid.UUID `json:"template_id,omitempty"` - OrganizationID uuid.UUID `json:"organization_id"` - JobID uuid.UUID `json:"job_id"` + ID uuid.UUID `json:"id"` + CreatedAt time.Time `json:"created_at"` + TemplateID *uuid.UUID `json:"template_id,omitempty"` + OrganizationID uuid.UUID `json:"organization_id"` + JobID uuid.UUID `json:"job_id"` + SourceExampleID *string `json:"source_example_id,omitempty"` } type ProvisionerJob struct { diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index 214d111a170a1..2b70cd2a6d2c3 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -1,6 +1,7 @@ package telemetry_test import ( + "database/sql" "encoding/json" "net/http" "net/http/httptest" @@ -48,6 +49,10 @@ func TestTelemetry(t *testing.T) { _ = dbgen.Template(t, db, database.Template{ Provisioner: database.ProvisionerTypeTerraform, }) + sourceExampleID := uuid.NewString() + _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{ + SourceExampleID: sql.NullString{String: sourceExampleID, Valid: true}, + }) _ = dbgen.TemplateVersion(t, db, database.TemplateVersion{}) user := dbgen.User(t, db, database.User{}) _ = dbgen.Workspace(t, db, database.WorkspaceTable{}) @@ -93,7 +98,7 @@ func TestTelemetry(t *testing.T) { require.Len(t, snapshot.ProvisionerJobs, 1) require.Len(t, snapshot.Licenses, 1) require.Len(t, snapshot.Templates, 1) - require.Len(t, snapshot.TemplateVersions, 1) + require.Len(t, snapshot.TemplateVersions, 2) require.Len(t, snapshot.Users, 1) require.Len(t, snapshot.Groups, 2) // 1 member in the everyone group + 1 member in the custom group @@ -111,6 +116,17 @@ func TestTelemetry(t *testing.T) { require.Len(t, wsa.Subsystems, 2) require.Equal(t, string(database.WorkspaceAgentSubsystemEnvbox), wsa.Subsystems[0]) require.Equal(t, string(database.WorkspaceAgentSubsystemExectrace), wsa.Subsystems[1]) + + tvs := snapshot.TemplateVersions + sort.Slice(tvs, func(i, j int) bool { + // Sort by SourceExampleID presence (non-nil comes before nil) + if (tvs[i].SourceExampleID != nil) != (tvs[j].SourceExampleID != nil) { + return tvs[i].SourceExampleID != nil + } + return false + }) + require.Equal(t, tvs[0].SourceExampleID, &sourceExampleID) + require.Nil(t, tvs[1].SourceExampleID) }) t.Run("HashedEmail", func(t *testing.T) { t.Parallel() diff --git a/coderd/templateversions.go b/coderd/templateversions.go index a0609c42c33f9..12def3e5d681b 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -1582,6 +1582,10 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht Readme: "", JobID: provisionerJob.ID, CreatedBy: apiKey.UserID, + SourceExampleID: sql.NullString{ + String: req.ExampleID, + Valid: req.ExampleID != "", + }, }) if err != nil { if database.IsUniqueViolation(err, database.UniqueTemplateVersionsTemplateIDNameKey) { diff --git a/coderd/templateversions_test.go b/coderd/templateversions_test.go index 5ebbd0f41804f..5e96de10d5058 100644 --- a/coderd/templateversions_test.go +++ b/coderd/templateversions_test.go @@ -16,6 +16,7 @@ import ( "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/externalauth" "github.com/coder/coder/v2/coderd/rbac" @@ -134,7 +135,7 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { t.Run("WithParameters", func(t *testing.T) { t.Parallel() auditor := audit.NewMock() - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true, Auditor: auditor}) + client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true, Auditor: auditor}) user := coderdtest.CreateFirstUser(t, client) data, err := echo.Tar(&echo.Responses{ Parse: echo.ParseComplete, @@ -160,11 +161,17 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { require.Len(t, auditor.AuditLogs(), 2) assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs()[1].Action) + + admin, err := client.User(ctx, user.UserID.String()) + require.NoError(t, err) + tvDB, err := db.GetTemplateVersionByID(dbauthz.As(ctx, coderdtest.AuthzUserSubject(admin, user.OrganizationID)), version.ID) + require.NoError(t, err) + require.False(t, tvDB.SourceExampleID.Valid) }) t.Run("Example", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, nil) + client, db := coderdtest.NewWithDatabase(t, nil) user := coderdtest.CreateFirstUser(t, client) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) @@ -205,6 +212,12 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { require.NoError(t, err) require.Equal(t, "my-example", tv.Name) + admin, err := client.User(ctx, user.UserID.String()) + require.NoError(t, err) + tvDB, err := db.GetTemplateVersionByID(dbauthz.As(ctx, coderdtest.AuthzUserSubject(admin, user.OrganizationID)), tv.ID) + require.NoError(t, err) + require.Equal(t, ls[0].ID, tvDB.SourceExampleID.String) + // ensure the template tar was uploaded correctly fl, ct, err := client.Download(ctx, tv.Job.FileID) require.NoError(t, err) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 922d80f0e8085..6bc09e0e770f6 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" ) @@ -396,11 +397,9 @@ func (api *API) workspaceAgentLogs(rw http.ResponseWriter, r *http.Request) { } go httpapi.Heartbeat(ctx, conn) - ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) - defer wsNetConn.Close() // Also closes conn. + encoder := wsjson.NewEncoder[[]codersdk.WorkspaceAgentLog](conn, websocket.MessageText) + defer encoder.Close(websocket.StatusNormalClosure) - // The Go stdlib JSON encoder appends a newline character after message write. - encoder := json.NewEncoder(wsNetConn) err = encoder.Encode(convertWorkspaceAgentLogs(logs)) if err != nil { return @@ -740,16 +739,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { }) return } - ctx, nconn := codersdk.WebsocketNetConn(ctx, ws, websocket.MessageBinary) - defer nconn.Close() - - // Slurp all packets from the connection into io.Discard so pongs get sent - // by the websocket package. We don't do any reads ourselves so this is - // necessary. - go func() { - _, _ = io.Copy(io.Discard, nconn) - _ = nconn.Close() - }() + encoder := wsjson.NewEncoder[*tailcfg.DERPMap](ws, websocket.MessageBinary) + defer encoder.Close(websocket.StatusGoingAway) go func(ctx context.Context) { // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? @@ -767,7 +758,7 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { err := ws.Ping(ctx) cancel() if err != nil { - _ = nconn.Close() + _ = ws.Close(websocket.StatusGoingAway, "ping failed") return } } @@ -780,9 +771,8 @@ func (api *API) derpMapUpdates(rw http.ResponseWriter, r *http.Request) { for { derpMap := api.DERPMap() if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) { - err := json.NewEncoder(nconn).Encode(derpMap) + err := encoder.Encode(derpMap) if err != nil { - _ = nconn.Close() return } lastDERPMap = derpMap diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index acd0c6955ab7f..c8bd4354df153 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -19,6 +19,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/codersdk/drpc" + "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" ) @@ -161,36 +162,8 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after } return nil, nil, ReadBodyAsError(res) } - logs := make(chan ProvisionerJobLog) - closed := make(chan struct{}) - go func() { - defer close(closed) - defer close(logs) - defer conn.Close(websocket.StatusGoingAway, "") - var log ProvisionerJobLog - for { - msgType, msg, err := conn.Read(ctx) - if err != nil { - return - } - if msgType != websocket.MessageText { - return - } - err = json.Unmarshal(msg, &log) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logs <- log: - } - } - }() - return logs, closeFunc(func() error { - <-closed - return nil - }), nil + d := wsjson.NewDecoder[ProvisionerJobLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } // ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index eeb335b130cdd..b4aec16a83190 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -15,6 +15,7 @@ import ( "nhooyr.io/websocket" "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk/wsjson" ) type WorkspaceAgentStatus string @@ -454,30 +455,6 @@ func (c *Client) WorkspaceAgentLogsAfter(ctx context.Context, agentID uuid.UUID, } return nil, nil, ReadBodyAsError(res) } - logChunks := make(chan []WorkspaceAgentLog, 1) - closed := make(chan struct{}) - ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageText) - decoder := json.NewDecoder(wsNetConn) - go func() { - defer close(closed) - defer close(logChunks) - defer conn.Close(websocket.StatusGoingAway, "") - for { - var logs []WorkspaceAgentLog - err = decoder.Decode(&logs) - if err != nil { - return - } - select { - case <-ctx.Done(): - return - case logChunks <- logs: - } - } - }() - return logChunks, closeFunc(func() error { - _ = wsNetConn.Close() - <-closed - return nil - }), nil + d := wsjson.NewDecoder[[]WorkspaceAgentLog](conn, websocket.MessageText, c.logger) + return d.Chan(), d, nil } diff --git a/codersdk/wsjson/decoder.go b/codersdk/wsjson/decoder.go new file mode 100644 index 0000000000000..4cc7ff380a73a --- /dev/null +++ b/codersdk/wsjson/decoder.go @@ -0,0 +1,75 @@ +package wsjson + +import ( + "context" + "encoding/json" + "sync/atomic" + + "nhooyr.io/websocket" + + "cdr.dev/slog" +) + +type Decoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType + ctx context.Context + cancel context.CancelFunc + chanCalled atomic.Bool + logger slog.Logger +} + +// Chan starts the decoder reading from the websocket and returns a channel for reading the +// resulting values. The chan T is closed if the underlying websocket is closed, or we encounter an +// error. We also close the underlying websocket if we encounter an error reading or decoding. +func (d *Decoder[T]) Chan() <-chan T { + if !d.chanCalled.CompareAndSwap(false, true) { + panic("chan called more than once") + } + values := make(chan T, 1) + go func() { + defer close(values) + defer d.conn.Close(websocket.StatusGoingAway, "") + for { + // we don't use d.ctx here because it only gets canceled after closing the connection + // and a "connection closed" type error is more clear than context canceled. + typ, b, err := d.conn.Read(context.Background()) + if err != nil { + // might be benign like EOF, so just log at debug + d.logger.Debug(d.ctx, "error reading from websocket", slog.Error(err)) + return + } + if typ != d.typ { + d.logger.Error(d.ctx, "websocket type mismatch while decoding") + return + } + var value T + err = json.Unmarshal(b, &value) + if err != nil { + d.logger.Error(d.ctx, "error unmarshalling", slog.Error(err)) + return + } + select { + case values <- value: + // OK + case <-d.ctx.Done(): + return + } + } + }() + return values +} + +// nolint: revive // complains that Encoder has the same function name +func (d *Decoder[T]) Close() error { + err := d.conn.Close(websocket.StatusNormalClosure, "") + d.cancel() + return err +} + +// NewDecoder creates a JSON-over-websocket decoder for type T, which must be deserializable from +// JSON. +func NewDecoder[T any](conn *websocket.Conn, typ websocket.MessageType, logger slog.Logger) *Decoder[T] { + ctx, cancel := context.WithCancel(context.Background()) + return &Decoder[T]{conn: conn, ctx: ctx, cancel: cancel, typ: typ, logger: logger} +} diff --git a/codersdk/wsjson/encoder.go b/codersdk/wsjson/encoder.go new file mode 100644 index 0000000000000..4cde05984e690 --- /dev/null +++ b/codersdk/wsjson/encoder.go @@ -0,0 +1,42 @@ +package wsjson + +import ( + "context" + "encoding/json" + + "golang.org/x/xerrors" + "nhooyr.io/websocket" +) + +type Encoder[T any] struct { + conn *websocket.Conn + typ websocket.MessageType +} + +func (e *Encoder[T]) Encode(v T) error { + w, err := e.conn.Writer(context.Background(), e.typ) + if err != nil { + return xerrors.Errorf("get websocket writer: %w", err) + } + defer w.Close() + j := json.NewEncoder(w) + err = j.Encode(v) + if err != nil { + return xerrors.Errorf("encode json: %w", err) + } + return nil +} + +func (e *Encoder[T]) Close(c websocket.StatusCode) error { + return e.conn.Close(c, "") +} + +// NewEncoder creates a JSON-over websocket encoder for the type T, which must be JSON-serializable. +// You may then call Encode() to send objects over the websocket. Creating an Encoder closes the +// websocket for reading, turning it into a unidirectional write stream of JSON-encoded objects. +func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] { + // Here we close the websocket for reading, so that the websocket library will handle pings and + // close frames. + _ = conn.CloseRead(context.Background()) + return &Encoder[T]{conn: conn, typ: typ} +} diff --git a/docs/admin/security/audit-logs.md b/docs/admin/security/audit-logs.md index 3ea4e145d13eb..db214b0e1443e 100644 --- a/docs/admin/security/audit-logs.md +++ b/docs/admin/security/audit-logs.md @@ -24,7 +24,7 @@ We track the following resources: | OAuth2ProviderAppSecret
|
FieldTracked
app_idfalse
created_atfalse
display_secretfalse
hashed_secretfalse
idfalse
last_used_atfalse
secret_prefixfalse
| | Organization
|
FieldTracked
created_atfalse
descriptiontrue
display_nametrue
icontrue
idfalse
is_defaulttrue
nametrue
updated_attrue
| | Template
write, delete |
FieldTracked
active_version_idtrue
activity_bumptrue
allow_user_autostarttrue
allow_user_autostoptrue
allow_user_cancel_workspace_jobstrue
autostart_block_days_of_weektrue
autostop_requirement_days_of_weektrue
autostop_requirement_weekstrue
created_atfalse
created_bytrue
created_by_avatar_urlfalse
created_by_usernamefalse
default_ttltrue
deletedfalse
deprecatedtrue
descriptiontrue
display_nametrue
failure_ttltrue
group_acltrue
icontrue
idtrue
max_port_sharing_leveltrue
nametrue
organization_display_namefalse
organization_iconfalse
organization_idfalse
organization_namefalse
provisionertrue
require_active_versiontrue
time_til_dormanttrue
time_til_dormant_autodeletetrue
updated_atfalse
user_acltrue
| -| TemplateVersion
create, write |
FieldTracked
archivedtrue
created_atfalse
created_bytrue
created_by_avatar_urlfalse
created_by_usernamefalse
external_auth_providersfalse
idtrue
job_idfalse
messagefalse
nametrue
organization_idfalse
readmetrue
template_idtrue
updated_atfalse
| +| TemplateVersion
create, write |
FieldTracked
archivedtrue
created_atfalse
created_bytrue
created_by_avatar_urlfalse
created_by_usernamefalse
external_auth_providersfalse
idtrue
job_idfalse
messagefalse
nametrue
organization_idfalse
readmetrue
source_example_idfalse
template_idtrue
updated_atfalse
| | User
create, write, delete |
FieldTracked
avatar_urlfalse
created_atfalse
deletedtrue
emailtrue
github_com_user_idfalse
hashed_one_time_passcodefalse
hashed_passwordtrue
idtrue
last_seen_atfalse
login_typetrue
nametrue
one_time_passcode_expires_attrue
quiet_hours_scheduletrue
rbac_rolestrue
statustrue
theme_preferencefalse
updated_atfalse
usernametrue
| | WorkspaceBuild
start, stop |
FieldTracked
build_numberfalse
created_atfalse
daily_costfalse
deadlinefalse
idfalse
initiator_by_avatar_urlfalse
initiator_by_usernamefalse
initiator_idfalse
job_idfalse
max_deadlinefalse
provisioner_statefalse
reasonfalse
template_version_idtrue
transitionfalse
updated_atfalse
workspace_idfalse
| | WorkspaceProxy
|
FieldTracked
created_attrue
deletedfalse
derp_enabledtrue
derp_onlytrue
display_nametrue
icontrue
idtrue
nametrue
region_idtrue
token_hashed_secrettrue
updated_atfalse
urltrue
versiontrue
wildcard_hostnametrue
| diff --git a/enterprise/audit/table.go b/enterprise/audit/table.go index f9e74959f2a28..24f7dfa4b4fe0 100644 --- a/enterprise/audit/table.go +++ b/enterprise/audit/table.go @@ -127,6 +127,7 @@ var auditableResourcesTypes = map[any]map[string]Action{ "created_by_avatar_url": ActionIgnore, "created_by_username": ActionIgnore, "archived": ActionTrack, + "source_example_id": ActionIgnore, // Never changes. }, &database.User{}: { "id": ActionTrack, diff --git a/enterprise/dbcrypt/cipher_internal_test.go b/enterprise/dbcrypt/cipher_internal_test.go index b6740de17eec6..c70796ba27e97 100644 --- a/enterprise/dbcrypt/cipher_internal_test.go +++ b/enterprise/dbcrypt/cipher_internal_test.go @@ -3,6 +3,8 @@ package dbcrypt import ( "bytes" "encoding/base64" + "os" + "strings" "testing" "github.com/stretchr/testify/require" @@ -89,3 +91,35 @@ func TestCiphersBackwardCompatibility(t *testing.T) { require.NoError(t, err, "decryption should succeed") require.Equal(t, msg, string(decrypted), "decrypted message should match original message") } + +// If you're looking here, you're probably in trouble. +// Here's what you need to do: +// 1. Get the current CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable. +// 2. Run the following command: +// ENCRYPT_ME="" CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS="" go test -v -count=1 ./enterprise/dbcrypt -test.run='^TestHelpMeEncryptSomeValue$' +// 3. Copy the value from the test output and do what you need with it. +func TestHelpMeEncryptSomeValue(t *testing.T) { + t.Parallel() + t.Skip("this only exists if you need to encrypt a value with dbcrypt, it does not actually test anything") + + valueToEncrypt := os.Getenv("ENCRYPT_ME") + t.Logf("valueToEncrypt: %q", valueToEncrypt) + keys := os.Getenv("CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS") + require.NotEmpty(t, keys, "Set the CODER_EXTERNAL_TOKEN_ENCRYPTION_KEYS environment variable to use this") + + base64Keys := strings.Split(keys, ",") + activeKey := base64Keys[0] + + decodedKey, err := base64.StdEncoding.DecodeString(activeKey) + require.NoError(t, err, "the active key should be valid base64") + + cipher, err := cipherAES256(decodedKey) + require.NoError(t, err) + + t.Logf("cipher digest: %+v", cipher.HexDigest()) + + encryptedEmptyString, err := cipher.Encrypt([]byte(valueToEncrypt)) + require.NoError(t, err) + + t.Logf("encrypted and base64-encoded: %q", base64.StdEncoding.EncodeToString(encryptedEmptyString)) +} diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 77a7d5cb78738..e0ca58cc5231a 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -261,6 +261,21 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U return link, nil } +func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error { + // We would normally use a sql.NullString here, but sqlc does not want to make + // a params struct with a nullable string. + var digest sql.NullString + if params.OAuthRefreshTokenKeyID != "" { + digest.String = params.OAuthRefreshTokenKeyID + digest.Valid = true + } + if err := db.encryptField(¶ms.OAuthRefreshToken, &digest); err != nil { + return err + } + + return db.Store.UpdateExternalAuthLinkRefreshToken(ctx, params) +} + func (db *dbCrypt) GetCryptoKeys(ctx context.Context) ([]database.CryptoKey, error) { keys, err := db.Store.GetCryptoKeys(ctx) if err != nil { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 3e252496d6a69..e73c3eee85c16 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -17,6 +17,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/database/dbtime" ) func TestUserLinks(t *testing.T) { @@ -96,6 +97,31 @@ func TestUserLinks(t *testing.T) { require.EqualValues(t, expectedClaims, rawLink.Claims) }) + t.Run("UpdateExternalAuthLinkRefreshToken", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + user := dbgen.User(t, crypt, database.User{}) + link := dbgen.ExternalAuthLink(t, crypt, database.ExternalAuthLink{ + UserID: user.ID, + }) + + err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{ + OAuthRefreshToken: "", + OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String, + UpdatedAt: dbtime.Now(), + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + + rawLink, err := db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{ + ProviderID: link.ProviderID, + UserID: link.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawLink.OAuthRefreshToken, "") + }) + t.Run("GetUserLinkByLinkedID", func(t *testing.T) { t.Parallel() t.Run("OK", func(t *testing.T) { diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 3ad195f2bd9e4..cfba27408e9c6 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -682,12 +682,20 @@ class ApiMethods { /** * @param organization Can be the organization's ID or name + * @param tags to filter provisioner daemons by. */ getProvisionerDaemonsByOrganization = async ( organization: string, + tags?: Record, ): Promise => { + const params = new URLSearchParams(); + + if (tags) { + params.append("tags", JSON.stringify(tags)); + } + const response = await this.axios.get( - `/api/v2/organizations/${organization}/provisionerdaemons`, + `/api/v2/organizations/${organization}/provisionerdaemons?${params.toString()}`, ); return response.data; }; diff --git a/site/src/api/queries/organizations.ts b/site/src/api/queries/organizations.ts index d1df8f409dcdf..c3f5a4ebd3ced 100644 --- a/site/src/api/queries/organizations.ts +++ b/site/src/api/queries/organizations.ts @@ -115,16 +115,18 @@ export const organizations = () => { }; }; -export const getProvisionerDaemonsKey = (organization: string) => [ - "organization", - organization, - "provisionerDaemons", -]; +export const getProvisionerDaemonsKey = ( + organization: string, + tags?: Record, +) => ["organization", organization, tags, "provisionerDaemons"]; -export const provisionerDaemons = (organization: string) => { +export const provisionerDaemons = ( + organization: string, + tags?: Record, +) => { return { - queryKey: getProvisionerDaemonsKey(organization), - queryFn: () => API.getProvisionerDaemonsByOrganization(organization), + queryKey: getProvisionerDaemonsKey(organization, tags), + queryFn: () => API.getProvisionerDaemonsByOrganization(organization, tags), }; }; diff --git a/site/src/components/Alert/Alert.tsx b/site/src/components/Alert/Alert.tsx index df741a1924fa9..7750a6bc7d1e8 100644 --- a/site/src/components/Alert/Alert.tsx +++ b/site/src/components/Alert/Alert.tsx @@ -1,4 +1,5 @@ import MuiAlert, { + type AlertColor as MuiAlertColor, type AlertProps as MuiAlertProps, // biome-ignore lint/nursery/noRestrictedImports: Used as base component } from "@mui/material/Alert"; @@ -11,6 +12,8 @@ import { useState, } from "react"; +export type AlertColor = MuiAlertColor; + export type AlertProps = MuiAlertProps & { actions?: ReactNode; dismissible?: boolean; diff --git a/site/src/modules/provisioners/ProvisionerAlert.stories.tsx b/site/src/modules/provisioners/ProvisionerAlert.stories.tsx new file mode 100644 index 0000000000000..d9ca1501d6611 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerAlert.stories.tsx @@ -0,0 +1,28 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import { chromatic } from "testHelpers/chromatic"; +import { ProvisionerAlert } from "./ProvisionerAlert"; + +const meta: Meta = { + title: "modules/provisioners/ProvisionerAlert", + parameters: { + chromatic, + layout: "centered", + }, + component: ProvisionerAlert, + args: { + title: "Title", + detail: "Detail", + severity: "info", + tags: { tag: "tagValue" }, + }, +}; + +export default meta; +type Story = StoryObj; + +export const Info: Story = {}; +export const NullTags: Story = { + args: { + tags: undefined, + }, +}; diff --git a/site/src/modules/provisioners/ProvisionerAlert.tsx b/site/src/modules/provisioners/ProvisionerAlert.tsx new file mode 100644 index 0000000000000..54d9ab8473e87 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerAlert.tsx @@ -0,0 +1,45 @@ +import AlertTitle from "@mui/material/AlertTitle"; +import { Alert, type AlertColor } from "components/Alert/Alert"; +import { AlertDetail } from "components/Alert/Alert"; +import { Stack } from "components/Stack/Stack"; +import { ProvisionerTag } from "modules/provisioners/ProvisionerTag"; +import type { FC } from "react"; +interface ProvisionerAlertProps { + title: string; + detail: string; + severity: AlertColor; + tags: Record; +} + +export const ProvisionerAlert: FC = ({ + title, + detail, + severity, + tags, +}) => { + return ( + { + return { + borderRadius: 0, + border: 0, + borderBottom: `1px solid ${theme.palette.divider}`, + borderLeft: `2px solid ${theme.palette[severity].main}`, + }; + }} + > + {title} + +
{detail}
+ + {Object.entries(tags ?? {}) + .filter(([key]) => key !== "owner") + .map(([key, value]) => ( + + ))} + +
+
+ ); +}; diff --git a/site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx b/site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx new file mode 100644 index 0000000000000..d4f746e99c417 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerStatusAlert.stories.tsx @@ -0,0 +1,55 @@ +import type { Meta, StoryObj } from "@storybook/react"; +import { chromatic } from "testHelpers/chromatic"; +import { MockTemplateVersion } from "testHelpers/entities"; +import { ProvisionerStatusAlert } from "./ProvisionerStatusAlert"; + +const meta: Meta = { + title: "modules/provisioners/ProvisionerStatusAlert", + parameters: { + chromatic, + layout: "centered", + }, + component: ProvisionerStatusAlert, + args: { + matchingProvisioners: 0, + availableProvisioners: 0, + tags: MockTemplateVersion.job.tags, + }, +}; + +export default meta; +type Story = StoryObj; + +export const HealthyProvisioners: Story = { + args: { + matchingProvisioners: 1, + availableProvisioners: 1, + }, +}; + +export const UndefinedMatchingProvisioners: Story = { + args: { + matchingProvisioners: undefined, + availableProvisioners: undefined, + }, +}; + +export const UndefinedAvailableProvisioners: Story = { + args: { + matchingProvisioners: 1, + availableProvisioners: undefined, + }, +}; + +export const NoMatchingProvisioners: Story = { + args: { + matchingProvisioners: 0, + }, +}; + +export const NoAvailableProvisioners: Story = { + args: { + matchingProvisioners: 1, + availableProvisioners: 0, + }, +}; diff --git a/site/src/modules/provisioners/ProvisionerStatusAlert.tsx b/site/src/modules/provisioners/ProvisionerStatusAlert.tsx new file mode 100644 index 0000000000000..54a2b56704877 --- /dev/null +++ b/site/src/modules/provisioners/ProvisionerStatusAlert.tsx @@ -0,0 +1,47 @@ +import type { AlertColor } from "components/Alert/Alert"; +import type { FC } from "react"; +import { ProvisionerAlert } from "./ProvisionerAlert"; + +interface ProvisionerStatusAlertProps { + matchingProvisioners: number | undefined; + availableProvisioners: number | undefined; + tags: Record; +} + +export const ProvisionerStatusAlert: FC = ({ + matchingProvisioners, + availableProvisioners, + tags, +}) => { + let title: string; + let detail: string; + let severity: AlertColor; + switch (true) { + case matchingProvisioners === 0: + title = "Build pending provisioner deployment"; + detail = + "Your build has been enqueued, but there are no provisioners that accept the required tags. Once a compatible provisioner becomes available, your build will continue. Please contact your administrator."; + severity = "warning"; + break; + case availableProvisioners === 0: + title = "Build delayed"; + detail = + "Provisioners that accept the required tags have not responded for longer than expected. This may delay your build. Please contact your administrator if your build does not complete."; + severity = "warning"; + break; + default: + title = "Build enqueued"; + detail = + "Your build has been enqueued and will begin once a provisioner becomes available to process it."; + severity = "info"; + } + + return ( + + ); +}; diff --git a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx index afc3c1321a6b4..29229fadfd0ad 100644 --- a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx +++ b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.stories.tsx @@ -34,6 +34,42 @@ export const MissingVariables: Story = { }, }; +export const NoProvisioners: Story = { + args: { + templateVersion: { + ...MockTemplateVersion, + matched_provisioners: { + count: 0, + available: 0, + }, + }, + }, +}; + +export const ProvisionersUnhealthy: Story = { + args: { + templateVersion: { + ...MockTemplateVersion, + matched_provisioners: { + count: 1, + available: 0, + }, + }, + }, +}; + +export const ProvisionersHealthy: Story = { + args: { + templateVersion: { + ...MockTemplateVersion, + matched_provisioners: { + count: 1, + available: 1, + }, + }, + }, +}; + export const Logs: Story = { args: { templateVersion: { diff --git a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx index 5af38b649c695..4eb1805b60e36 100644 --- a/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx +++ b/site/src/pages/CreateTemplatePage/BuildLogsDrawer.tsx @@ -8,6 +8,7 @@ import { visuallyHidden } from "@mui/utils"; import { JobError } from "api/queries/templates"; import type { TemplateVersion } from "api/typesGenerated"; import { Loader } from "components/Loader/Loader"; +import { ProvisionerStatusAlert } from "modules/provisioners/ProvisionerStatusAlert"; import { useWatchVersionLogs } from "modules/templates/useWatchVersionLogs"; import { WorkspaceBuildLogs } from "modules/workspaces/WorkspaceBuildLogs/WorkspaceBuildLogs"; import { type FC, useLayoutEffect, useRef } from "react"; @@ -27,6 +28,10 @@ export const BuildLogsDrawer: FC = ({ variablesSectionRef, ...drawerProps }) => { + const matchingProvisioners = templateVersion?.matched_provisioners?.count; + const availableProvisioners = + templateVersion?.matched_provisioners?.available; + const logs = useWatchVersionLogs(templateVersion); const logsContainer = useRef(null); @@ -65,6 +70,8 @@ export const BuildLogsDrawer: FC = ({ + {} + {isMissingVariables ? ( { @@ -82,7 +89,14 @@ export const BuildLogsDrawer: FC = ({ ) : ( - + <> + + + )} diff --git a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx index 9147a1a5befff..05ed426d5dcc9 100644 --- a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx +++ b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.stories.tsx @@ -42,6 +42,7 @@ const meta: Meta = { deploymentDAUs: MockDeploymentDAUResponse, invalidExperiments: [], safeExperiments: [], + entitlements: undefined, }, }; @@ -136,3 +137,74 @@ export const invalidExperimentsEnabled: Story = { invalidExperiments: ["invalid"], }, }; + +export const WithLicenseUtilization: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: true, + actual: 75, + limit: 100, + entitlement: "entitled", + }, + }, + }, + }, +}; + +export const HighLicenseUtilization: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: true, + actual: 95, + limit: 100, + entitlement: "entitled", + }, + }, + }, + }, +}; + +export const ExceedsLicenseUtilization: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: true, + actual: 100, + limit: 95, + entitlement: "entitled", + }, + }, + }, + }, +}; +export const NoLicenseLimit: Story = { + args: { + entitlements: { + ...MockEntitlementsWithUserLimit, + features: { + ...MockEntitlementsWithUserLimit.features, + user_limit: { + ...MockEntitlementsWithUserLimit.features.user_limit, + enabled: false, + actual: 0, + limit: 0, + entitlement: "entitled", + }, + }, + }, + }, +}; diff --git a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx index 29edacd08d9e7..df5550d70e965 100644 --- a/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx +++ b/site/src/pages/DeploymentSettingsPage/GeneralSettingsPage/GeneralSettingsPageView.tsx @@ -1,4 +1,5 @@ import AlertTitle from "@mui/material/AlertTitle"; +import LinearProgress from "@mui/material/LinearProgress"; import type { DAUsResponse, Entitlements, @@ -36,6 +37,12 @@ export const GeneralSettingsPageView: FC = ({ safeExperiments, invalidExperiments, }) => { + const licenseUtilizationPercentage = + entitlements?.features?.user_limit?.actual && + entitlements?.features?.user_limit?.limit + ? entitlements.features.user_limit.actual / + entitlements.features.user_limit.limit + : undefined; return ( <> = ({ )} + {licenseUtilizationPercentage && ( + + + + {Math.round(licenseUtilizationPercentage * 100)}% used ( + {entitlements!.features.user_limit.actual}/ + {entitlements!.features.user_limit.limit} users) + + + )} {invalidExperiments.length > 0 && ( Invalid experiments in use: diff --git a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx index 1382aa100a1dc..4b8413215c9e8 100644 --- a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx +++ b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.stories.tsx @@ -49,6 +49,73 @@ type Story = StoryObj; export const Example: Story = {}; +export const UndefinedLogs: Story = { + args: { + defaultTab: "logs", + buildLogs: undefined, + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + }, + }, +}; + +export const EmptyLogs: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + }, + }, +}; + +export const NoProvisioners: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + matched_provisioners: { + count: 0, + available: 0, + }, + }, + }, +}; + +export const UnavailableProvisioners: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + matched_provisioners: { + count: 1, + available: 0, + }, + }, + }, +}; + +export const HealthyProvisioners: Story = { + args: { + defaultTab: "logs", + buildLogs: [], + templateVersion: { + ...MockTemplateVersion, + job: MockRunningProvisionerJob, + matched_provisioners: { + count: 1, + available: 1, + }, + }, + }, +}; + export const Logs: Story = { args: { defaultTab: "logs", diff --git a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx index 943370f89e2a4..858f57dd59493 100644 --- a/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx +++ b/site/src/pages/TemplateVersionEditorPage/TemplateVersionEditor.tsx @@ -4,7 +4,6 @@ import ArrowBackOutlined from "@mui/icons-material/ArrowBackOutlined"; import CloseOutlined from "@mui/icons-material/CloseOutlined"; import PlayArrowOutlined from "@mui/icons-material/PlayArrowOutlined"; import WarningOutlined from "@mui/icons-material/WarningOutlined"; -import AlertTitle from "@mui/material/AlertTitle"; import Button from "@mui/material/Button"; import ButtonGroup from "@mui/material/ButtonGroup"; import IconButton from "@mui/material/IconButton"; @@ -17,7 +16,7 @@ import type { VariableValue, WorkspaceResource, } from "api/typesGenerated"; -import { Alert, AlertDetail } from "components/Alert/Alert"; +import { Alert } from "components/Alert/Alert"; import { Sidebar } from "components/FullPageLayout/Sidebar"; import { Topbar, @@ -29,6 +28,8 @@ import { } from "components/FullPageLayout/Topbar"; import { Loader } from "components/Loader/Loader"; import { linkToTemplate, useLinks } from "modules/navigation"; +import { ProvisionerAlert } from "modules/provisioners/ProvisionerAlert"; +import { ProvisionerStatusAlert } from "modules/provisioners/ProvisionerStatusAlert"; import { TemplateFileTree } from "modules/templates/TemplateFiles/TemplateFileTree"; import { isBinaryData } from "modules/templates/TemplateFiles/isBinaryData"; import { TemplateResourcesTable } from "modules/templates/TemplateResourcesTable/TemplateResourcesTable"; @@ -126,6 +127,8 @@ export const TemplateVersionEditor: FC = ({ const [deleteFileOpen, setDeleteFileOpen] = useState(); const [renameFileOpen, setRenameFileOpen] = useState(); const [dirty, setDirty] = useState(false); + const matchingProvisioners = templateVersion.matched_provisioners?.count; + const availableProvisioners = templateVersion.matched_provisioners?.available; const triggerPreview = useCallback(async () => { await onPreview(fileTree); @@ -192,6 +195,8 @@ export const TemplateVersionEditor: FC = ({ linkToTemplate(template.organization_name, template.name), ); + const gotBuildLogs = buildLogs && buildLogs.length > 0; + return ( <>
@@ -581,31 +586,34 @@ export const TemplateVersionEditor: FC = ({ css={[styles.logs, styles.tabContent]} ref={logsContentRef} > - {templateVersion.job.error && ( + {templateVersion.job.error ? (
- - Error during the build - {templateVersion.job.error} - + tags={templateVersion.job.tags} + />
+ ) : ( + !gotBuildLogs && ( + <> + + + + ) )} - {buildLogs && buildLogs.length > 0 ? ( + {gotBuildLogs && ( - ) : ( - )}
)} diff --git a/site/src/testHelpers/storybook.tsx b/site/src/testHelpers/storybook.tsx index e905a9b412c2c..514d34e0265e8 100644 --- a/site/src/testHelpers/storybook.tsx +++ b/site/src/testHelpers/storybook.tsx @@ -1,6 +1,7 @@ import type { StoryContext } from "@storybook/react"; import { withDefaultFeatures } from "api/api"; import { getAuthorizationKey } from "api/queries/authCheck"; +import { getProvisionerDaemonsKey } from "api/queries/organizations"; import { hasFirstUserKey, meKey } from "api/queries/users"; import type { Entitlements } from "api/typesGenerated"; import { GlobalSnackbar } from "components/GlobalSnackbar/GlobalSnackbar"; @@ -121,6 +122,30 @@ export const withAuthProvider = (Story: FC, { parameters }: StoryContext) => { ); }; +export const withProvisioners = (Story: FC, { parameters }: StoryContext) => { + if (!parameters.organization_id) { + throw new Error( + "You forgot to add `parameters.organization_id` to your story", + ); + } + if (!parameters.provisioners) { + throw new Error( + "You forgot to add `parameters.provisioners` to your story", + ); + } + if (!parameters.tags) { + throw new Error("You forgot to add `parameters.tags` to your story"); + } + + const queryClient = useQueryClient(); + queryClient.setQueryData( + getProvisionerDaemonsKey(parameters.organization_id, parameters.tags), + parameters.provisioners, + ); + + return ; +}; + export const withGlobalSnackbar = (Story: FC) => ( <>