diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 41fa20392fadf..8bfede13e8c72 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -575,11 +575,6 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r return nil } -func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { - // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. - return q.GetTemplatesWithFilter(ctx, arg) -} - func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error { deleteF := func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{ @@ -591,34 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) } -func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - // An actor is authorized to read template group roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateGroupRoles(ctx, id) -} - -func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - // An actor is authorized to query template user roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateUserRoles(ctx, id) -} - -func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - return q.db.GetAuthorizedUserCount(ctx, arg, prepared) -} - func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { // TODO Implement this with a SQL filter. The count is incorrect without it. rowUsers, err := q.db.GetUsers(ctx, arg) @@ -655,11 +622,6 @@ func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id) } -func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. - return q.GetWorkspaces(ctx, arg) -} - func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ @@ -2642,3 +2604,41 @@ func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (d } return q.db.UpsertTailnetCoordinator(ctx, id) } + +func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { + // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. + return q.GetTemplatesWithFilter(ctx, arg) +} + +func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + // An actor is authorized to read template group roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateGroupRoles(ctx, id) +} + +func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + // An actor is authorized to query template user roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateUserRoles(ctx, id) +} + +func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. + return q.GetWorkspaces(ctx, arg) +} + +func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.GetAuthorizedUserCount(ctx, arg, prepared) +} diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index bdcfd9366dabb..bf2409de28b54 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -266,80 +266,6 @@ func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(params); err != nil { - return 0, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return -1, err - } - } - - users := make([]database.User, 0, len(q.users)) - - for _, user := range q.users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue - } - - users = append(users, user) - } - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) - } - } - users = tmp - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - - users = usersFilteredByRole - } - - return int64(len(users)), nil -} - func convertUsers(users []database.User, count int64) []database.GetUsersRow { rows := make([]database.GetUsersRow, len(users)) for i, u := range users { @@ -363,567 +289,176 @@ func convertUsers(users []database.User, count int64) []database.GetUsersRow { return rows } -//nolint:gocyclo -func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err +// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. +// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. +func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { + var status string + connectionTimeout := time.Duration(dbAgent.ConnectionTimeoutSeconds) * time.Second + switch { + case !dbAgent.FirstConnectedAt.Valid: + switch { + case connectionTimeout > 0 && database.Now().Sub(dbAgent.CreatedAt) > connectionTimeout: + // If the agent took too long to connect the first time, + // mark it as timed out. + status = "timeout" + default: + // If the agent never connected, it's waiting for the compute + // to start up. + status = "connecting" } + case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time): + // If we've disconnected after our last connection, we know the + // agent is no longer connected. + status = "disconnected" + case database.Now().Sub(dbAgent.LastConnectedAt.Time) > time.Duration(agentInactiveDisconnectTimeoutSeconds)*time.Second: + // The connection died without updating the last connected. + status = "disconnected" + case dbAgent.LastConnectedAt.Valid: + // The agent should be assumed connected if it's under inactivity timeouts + // and last connected at has been properly set. + status = "connected" + default: + panic("unknown agent status: " + status) } + return status +} - workspaces := make([]database.Workspace, 0) - for _, workspace := range q.workspaces { - if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { - continue +func (q *FakeQuerier) convertToWorkspaceRowsNoLock(ctx context.Context, workspaces []database.Workspace, count int64) []database.GetWorkspacesRow { + rows := make([]database.GetWorkspacesRow, 0, len(workspaces)) + for _, w := range workspaces { + wr := database.GetWorkspacesRow{ + ID: w.ID, + CreatedAt: w.CreatedAt, + UpdatedAt: w.UpdatedAt, + OwnerID: w.OwnerID, + OrganizationID: w.OrganizationID, + TemplateID: w.TemplateID, + Deleted: w.Deleted, + Name: w.Name, + AutostartSchedule: w.AutostartSchedule, + Ttl: w.Ttl, + LastUsedAt: w.LastUsedAt, + Count: count, } - if arg.OwnerUsername != "" { - owner, err := q.getUserByIDNoLock(workspace.OwnerID) - if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { - continue + for _, t := range q.templates { + if t.ID == w.TemplateID { + wr.TemplateName = t.Name + break } } - if arg.TemplateName != "" { - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { - continue + if build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID); err == nil { + for _, tv := range q.templateVersions { + if tv.ID == build.TemplateVersionID { + wr.TemplateVersionID = tv.ID + wr.TemplateVersionName = sql.NullString{ + Valid: true, + String: tv.Name, + } + break + } } } - if !arg.Deleted && workspace.Deleted { - continue + rows = append(rows, wr) + } + return rows +} + +func (q *FakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { + for _, workspace := range q.workspaces { + if workspace.ID == id { + return workspace, nil } + } + return database.Workspace{}, sql.ErrNoRows +} - if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { - continue +func (q *FakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { + var agent database.WorkspaceAgent + for _, _agent := range q.workspaceAgents { + if _agent.ID == agentID { + agent = _agent + break } + } + if agent.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } - if arg.Status != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } + var resource database.WorkspaceResource + for _, _resource := range q.workspaceResources { + if _resource.ID == agent.ResourceID { + resource = _resource + break + } + } + if resource.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } + var build database.WorkspaceBuild + for _, _build := range q.workspaceBuilds { + if _build.JobID == resource.JobID { + build = _build + break + } + } + if build.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } - // This logic should match the logic in the workspace.sql file. - var statusMatch bool - switch database.WorkspaceStatus(arg.Status) { - case database.WorkspaceStatusPending: - statusMatch = isNull(job.StartedAt) - case database.WorkspaceStatusStarting: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionStart + for _, workspace := range q.workspaces { + if workspace.ID == build.WorkspaceID { + return workspace, nil + } + } - case database.WorkspaceStatusRunning: - statusMatch = isNotNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionStart + return database.Workspace{}, sql.ErrNoRows +} - case database.WorkspaceStatusStopping: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionStop - - case database.WorkspaceStatusStopped: - statusMatch = isNotNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusFailed: - statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) || - (isNotNull(job.CompletedAt) && isNotNull(job.Error)) - - case database.WorkspaceStatusCanceling: - statusMatch = isNotNull(job.CanceledAt) && - isNull(job.CompletedAt) - - case database.WorkspaceStatusCanceled: - statusMatch = isNotNull(job.CanceledAt) && - isNotNull(job.CompletedAt) - - case database.WorkspaceStatusDeleted: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNotNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionDelete && - isNull(job.Error) - - case database.WorkspaceStatusDeleting: - statusMatch = isNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionDelete - - default: - return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status) - } - if !statusMatch { - continue - } - } - - if arg.HasAgent != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - - var workspaceResourceIDs []uuid.UUID - for _, wr := range workspaceResources { - workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) - } - - workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - var hasAgentMatched bool - for _, wa := range workspaceAgents { - if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { - hasAgentMatched = true - } - } - - if !hasAgentMatched { - continue - } - } - - if len(arg.TemplateIds) > 0 { - match := false - for _, id := range arg.TemplateIds { - if workspace.TemplateID == id { - match = true - break - } - } - if !match { - continue - } - } - - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { - continue - } - workspaces = append(workspaces, workspace) - } - - // Sort workspaces (ORDER BY) - isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { - return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart - } - - preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} - preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} - preloadedUsers := map[uuid.UUID]database.User{} - - for _, w := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err == nil { - preloadedWorkspaceBuilds[w.ID] = build - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err == nil { - preloadedProvisionerJobs[w.ID] = job - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - user, err := q.getUserByIDNoLock(w.OwnerID) - if err == nil { - preloadedUsers[w.ID] = user - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user: %w", err) - } - } - - sort.Slice(workspaces, func(i, j int) bool { - w1 := workspaces[i] - w2 := workspaces[j] - - // Order by: running first - w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) - w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) - - if w1IsRunning && !w2IsRunning { - return true - } - - if !w1IsRunning && w2IsRunning { - return false - } - - // Order by: usernames - if w1.ID != w2.ID { - return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username}) - } - - // Order by: workspace names - return sort.StringsAreSorted([]string{w1.Name, w2.Name}) - }) - - beforePageCount := len(workspaces) - - if arg.Offset > 0 { - if int(arg.Offset) > len(workspaces) { - return []database.GetWorkspacesRow{}, nil - } - workspaces = workspaces[arg.Offset:] - } - if arg.Limit > 0 { - if int(arg.Limit) > len(workspaces) { - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil - } - workspaces = workspaces[:arg.Limit] - } - - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil -} - -// mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. -// The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. -func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { - var status string - connectionTimeout := time.Duration(dbAgent.ConnectionTimeoutSeconds) * time.Second - switch { - case !dbAgent.FirstConnectedAt.Valid: - switch { - case connectionTimeout > 0 && database.Now().Sub(dbAgent.CreatedAt) > connectionTimeout: - // If the agent took too long to connect the first time, - // mark it as timed out. - status = "timeout" - default: - // If the agent never connected, it's waiting for the compute - // to start up. - status = "connecting" - } - case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time): - // If we've disconnected after our last connection, we know the - // agent is no longer connected. - status = "disconnected" - case database.Now().Sub(dbAgent.LastConnectedAt.Time) > time.Duration(agentInactiveDisconnectTimeoutSeconds)*time.Second: - // The connection died without updating the last connected. - status = "disconnected" - case dbAgent.LastConnectedAt.Valid: - // The agent should be assumed connected if it's under inactivity timeouts - // and last connected at has been properly set. - status = "connected" - default: - panic("unknown agent status: " + status) - } - return status -} - -func (q *FakeQuerier) convertToWorkspaceRowsNoLock(ctx context.Context, workspaces []database.Workspace, count int64) []database.GetWorkspacesRow { - rows := make([]database.GetWorkspacesRow, 0, len(workspaces)) - for _, w := range workspaces { - wr := database.GetWorkspacesRow{ - ID: w.ID, - CreatedAt: w.CreatedAt, - UpdatedAt: w.UpdatedAt, - OwnerID: w.OwnerID, - OrganizationID: w.OrganizationID, - TemplateID: w.TemplateID, - Deleted: w.Deleted, - Name: w.Name, - AutostartSchedule: w.AutostartSchedule, - Ttl: w.Ttl, - LastUsedAt: w.LastUsedAt, - Count: count, - } - - for _, t := range q.templates { - if t.ID == w.TemplateID { - wr.TemplateName = t.Name - break - } - } - - if build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID); err == nil { - for _, tv := range q.templateVersions { - if tv.ID == build.TemplateVersionID { - wr.TemplateVersionID = tv.ID - wr.TemplateVersionName = sql.NullString{ - Valid: true, - String: tv.Name, - } - break - } - } - } - - rows = append(rows, wr) - } - return rows -} - -func (q *FakeQuerier) getWorkspaceByIDNoLock(_ context.Context, id uuid.UUID) (database.Workspace, error) { - for _, workspace := range q.workspaces { - if workspace.ID == id { - return workspace, nil - } - } - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceByAgentIDNoLock(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { - var agent database.WorkspaceAgent - for _, _agent := range q.workspaceAgents { - if _agent.ID == agentID { - agent = _agent - break - } - } - if agent.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - var resource database.WorkspaceResource - for _, _resource := range q.workspaceResources { - if _resource.ID == agent.ResourceID { - resource = _resource - break - } - } - if resource.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - var build database.WorkspaceBuild - for _, _build := range q.workspaceBuilds { - if _build.JobID == resource.JobID { - build = _build - break - } - } - if build.ID == uuid.Nil { - return database.Workspace{}, sql.ErrNoRows - } - - for _, workspace := range q.workspaces { - if workspace.ID == build.WorkspaceID { - return workspace, nil - } - } - - return database.Workspace{}, sql.ErrNoRows -} - -func (q *FakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { - for _, history := range q.workspaceBuilds { - if history.ID == id { - return history, nil - } - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} +func (q *FakeQuerier) getWorkspaceBuildByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceBuild, error) { + for _, history := range q.workspaceBuilds { + if history.ID == id { + return history, nil + } + } + return database.WorkspaceBuild{}, sql.ErrNoRows +} func (q *FakeQuerier) getLatestWorkspaceBuildByWorkspaceIDNoLock(_ context.Context, workspaceID uuid.UUID) (database.WorkspaceBuild, error) { - var row database.WorkspaceBuild - var buildNum int32 = -1 - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.WorkspaceID == workspaceID && workspaceBuild.BuildNumber > buildNum { - row = workspaceBuild - buildNum = workspaceBuild.BuildNumber - } - } - if buildNum == -1 { - return database.WorkspaceBuild{}, sql.ErrNoRows - } - return row, nil -} - -func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) { - for _, template := range q.templates { - if template.ID == id { - return template.DeepCopy(), nil - } - } - return database.Template{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) - if err != nil { - return nil, err - } - } - - var templates []database.Template - for _, template := range q.templates { - if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { - continue - } - - if template.Deleted != arg.Deleted { - continue - } - if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { - continue - } - - if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { - continue - } - - if len(arg.IDs) > 0 { - match := false - for _, id := range arg.IDs { - if template.ID == id { - match = true - break - } - } - if !match { - continue - } - } - templates = append(templates, template.DeepCopy()) - } - if len(templates) > 0 { - slices.SortFunc(templates, func(i, j database.Template) bool { - if i.Name != j.Name { - return i.Name < j.Name - } - return i.ID.String() < j.ID.String() - }) - return templates, nil - } - - return nil, sql.ErrNoRows -} - -func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { - for _, templateVersion := range q.templateVersions { - if templateVersion.ID != templateVersionID { - continue - } - return templateVersion, nil - } - return database.TemplateVersion{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.Template - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - users := make([]database.TemplateUser, 0, len(template.UserACL)) - for k, v := range template.UserACL { - user, err := q.getUserByIDNoLock(uuid.MustParse(k)) - if err != nil && xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - // We don't delete users from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - if user.Deleted || user.Status == database.UserStatusSuspended { - continue + var row database.WorkspaceBuild + var buildNum int32 = -1 + for _, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.WorkspaceID == workspaceID && workspaceBuild.BuildNumber > buildNum { + row = workspaceBuild + buildNum = workspaceBuild.BuildNumber } - - users = append(users, database.TemplateUser{ - User: user, - Actions: v, - }) } - - return users, nil + if buildNum == -1 { + return database.WorkspaceBuild{}, sql.ErrNoRows + } + return row, nil } -func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.Template - for _, t := range q.templates { - if t.ID == id { - template = t - break +func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (database.Template, error) { + for _, template := range q.templates { + if template.ID == id { + return template.DeepCopy(), nil } } + return database.Template{}, sql.ErrNoRows +} - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) - for k, v := range template.GroupACL { - group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get group by ID: %w", err) - } - // We don't delete groups from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { +func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { + for _, templateVersion := range q.templateVersions { + if templateVersion.ID != templateVersionID { continue } - - groups = append(groups, database.TemplateGroup{ - Group: group, - Actions: v, - }) + return templateVersion, nil } - - return groups, nil + return database.TemplateVersion{}, sql.ErrNoRows } func (q *FakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { @@ -3254,609 +2789,1141 @@ func (q *FakeQuerier) GetWorkspaceBuildsByWorkspaceID(_ context.Context, if len(history) == 0 { return nil, sql.ErrNoRows } - return history, nil + return history, nil +} + +func (q *FakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + workspaceBuilds := make([]database.WorkspaceBuild, 0) + for _, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.CreatedAt.After(after) { + workspaceBuilds = append(workspaceBuilds, workspaceBuild) + } + } + return workspaceBuilds, nil +} + +func (q *FakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + return q.getWorkspaceByAgentIDNoLock(ctx, agentID) +} + +func (q *FakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + return q.getWorkspaceByIDNoLock(ctx, id) +} + +func (q *FakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Workspace{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + var found *database.Workspace + for _, workspace := range q.workspaces { + workspace := workspace + if workspace.OwnerID != arg.OwnerID { + continue + } + if !strings.EqualFold(workspace.Name, arg.Name) { + continue + } + if workspace.Deleted != arg.Deleted { + continue + } + + // Return the most recent workspace with the given name + if found == nil || workspace.CreatedAt.After(found.CreatedAt) { + found = &workspace + } + } + if found != nil { + return *found, nil + } + return database.Workspace{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { + if err := validateDatabaseType(workspaceAppID); err != nil { + return database.Workspace{}, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, workspaceApp := range q.workspaceApps { + workspaceApp := workspaceApp + if workspaceApp.ID == workspaceAppID { + return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) + } + } + return database.Workspace{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) + + for _, p := range q.workspaceProxies { + if !p.Deleted { + cpy = append(cpy, p) + } + } + return cpy, nil +} + +func (q *FakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + // Return zero rows if this is called with a non-sanitized hostname. The SQL + // version of this query does the same thing. + if !validProxyByHostnameRegex.MatchString(params.Hostname) { + return database.WorkspaceProxy{}, sql.ErrNoRows + } + + // This regex matches the SQL version. + accessURLRegex := regexp.MustCompile(`[^:]*://` + regexp.QuoteMeta(params.Hostname) + `([:/]?.)*`) + + for _, proxy := range q.workspaceProxies { + if proxy.Deleted { + continue + } + if params.AllowAccessUrl && accessURLRegex.MatchString(proxy.Url) { + return proxy, nil + } + + // Compile the app hostname regex. This is slow sadly. + if params.AllowWildcardHostname { + wildcardRegexp, err := httpapi.CompileHostnamePattern(proxy.WildcardHostname) + if err != nil { + return database.WorkspaceProxy{}, xerrors.Errorf("compile hostname pattern %q for proxy %q (%s): %w", proxy.WildcardHostname, proxy.Name, proxy.ID.String(), err) + } + if _, ok := httpapi.ExecuteHostnamePattern(wildcardRegexp, params.Hostname); ok { + return proxy, nil + } + } + } + + return database.WorkspaceProxy{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, proxy := range q.workspaceProxies { + if proxy.ID == id { + return proxy, nil + } + } + return database.WorkspaceProxy{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetWorkspaceProxyByName(_ context.Context, name string) (database.WorkspaceProxy, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, proxy := range q.workspaceProxies { + if proxy.Deleted { + continue + } + if proxy.Name == name { + return proxy, nil + } + } + return database.WorkspaceProxy{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, resource := range q.workspaceResources { + if resource.ID == id { + return resource, nil + } + } + return database.WorkspaceResource{}, sql.ErrNoRows +} + +func (q *FakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + metadata := make([]database.WorkspaceResourceMetadatum, 0) + for _, metadatum := range q.workspaceResourceMetadata { + for _, id := range ids { + if metadatum.WorkspaceResourceID == id { + metadata = append(metadata, metadatum) + } + } + } + return metadata, nil +} + +func (q *FakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, after time.Time) ([]database.WorkspaceResourceMetadatum, error) { + resources, err := q.GetWorkspaceResourcesCreatedAfter(ctx, after) + if err != nil { + return nil, err + } + resourceIDs := map[uuid.UUID]struct{}{} + for _, resource := range resources { + resourceIDs[resource.ID] = struct{}{} + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + metadata := make([]database.WorkspaceResourceMetadatum, 0) + for _, m := range q.workspaceResourceMetadata { + _, ok := resourceIDs[m.WorkspaceResourceID] + if !ok { + continue + } + metadata = append(metadata, m) + } + return metadata, nil } -func (q *FakeQuerier) GetWorkspaceBuildsCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceBuild, error) { +func (q *FakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - workspaceBuilds := make([]database.WorkspaceBuild, 0) - for _, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.CreatedAt.After(after) { - workspaceBuilds = append(workspaceBuilds, workspaceBuild) - } - } - return workspaceBuilds, nil + return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) } -func (q *FakeQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (database.Workspace, error) { +func (q *FakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []uuid.UUID) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getWorkspaceByAgentIDNoLock(ctx, agentID) + resources := make([]database.WorkspaceResource, 0) + for _, resource := range q.workspaceResources { + for _, jobID := range jobIDs { + if resource.JobID != jobID { + continue + } + resources = append(resources, resource) + } + } + return resources, nil } -func (q *FakeQuerier) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (database.Workspace, error) { +func (q *FakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return q.getWorkspaceByIDNoLock(ctx, id) + resources := make([]database.WorkspaceResource, 0) + for _, resource := range q.workspaceResources { + if resource.CreatedAt.After(after) { + resources = append(resources, resource) + } + } + return resources, nil } -func (q *FakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { +func (q *FakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err + return nil, err } + // A nil auth filter means no auth filter. + workspaceRows, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) + return workspaceRows, err +} + +func (q *FakeQuerier) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() - var found *database.Workspace + workspaces := []database.Workspace{} for _, workspace := range q.workspaces { - workspace := workspace - if workspace.OwnerID != arg.OwnerID { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, err + } + + if build.Transition == database.WorkspaceTransitionStart && + !build.Deadline.IsZero() && + build.Deadline.Before(now) && + !workspace.LockedAt.Valid { + workspaces = append(workspaces, workspace) continue } - if !strings.EqualFold(workspace.Name, arg.Name) { + + if build.Transition == database.WorkspaceTransitionStop && + workspace.AutostartSchedule.Valid && + !workspace.LockedAt.Valid { + workspaces = append(workspaces, workspace) continue } - if workspace.Deleted != arg.Deleted { + + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job by ID: %w", err) + } + if db2sdk.ProvisionerJobStatus(job) == codersdk.ProvisionerJobFailed { + workspaces = append(workspaces, workspace) continue } - // Return the most recent workspace with the given name - if found == nil || workspace.CreatedAt.After(found.CreatedAt) { - found = &workspace + template, err := q.GetTemplateByID(ctx, workspace.TemplateID) + if err != nil { + return nil, xerrors.Errorf("get template by ID: %w", err) + } + if !workspace.LockedAt.Valid && template.InactivityTTL > 0 { + workspaces = append(workspaces, workspace) + continue + } + if workspace.LockedAt.Valid && template.LockedTTL > 0 { + workspaces = append(workspaces, workspace) + continue } } - if found != nil { - return *found, nil - } - return database.Workspace{}, sql.ErrNoRows + + return workspaces, nil } -func (q *FakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) { - if err := validateDatabaseType(workspaceAppID); err != nil { - return database.Workspace{}, err +func (q *FakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { + if err := validateDatabaseType(arg); err != nil { + return database.APIKey{}, err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - for _, workspaceApp := range q.workspaceApps { - workspaceApp := workspaceApp - if workspaceApp.ID == workspaceAppID { - return q.getWorkspaceByAgentIDNoLock(context.Background(), workspaceApp.AgentID) + if arg.LifetimeSeconds == 0 { + arg.LifetimeSeconds = 86400 + } + + for _, u := range q.users { + if u.ID == arg.UserID && u.Deleted { + return database.APIKey{}, xerrors.Errorf("refusing to create APIKey for deleted user") } } - return database.Workspace{}, sql.ErrNoRows + + //nolint:gosimple + key := database.APIKey{ + ID: arg.ID, + LifetimeSeconds: arg.LifetimeSeconds, + HashedSecret: arg.HashedSecret, + IPAddress: arg.IPAddress, + UserID: arg.UserID, + ExpiresAt: arg.ExpiresAt, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + LastUsed: arg.LastUsed, + LoginType: arg.LoginType, + Scope: arg.Scope, + TokenName: arg.TokenName, + } + q.apiKeys = append(q.apiKeys, key) + return key, nil } -func (q *FakeQuerier) GetWorkspaceProxies(_ context.Context) ([]database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *FakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) { + return q.InsertGroup(ctx, database.InsertGroupParams{ + ID: orgID, + Name: database.AllUsersGroup, + OrganizationID: orgID, + }) +} - cpy := make([]database.WorkspaceProxy, 0, len(q.workspaceProxies)) +func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { + if err := validateDatabaseType(arg); err != nil { + return database.AuditLog{}, err + } - for _, p := range q.workspaceProxies { - if !p.Deleted { - cpy = append(cpy, p) - } + q.mutex.Lock() + defer q.mutex.Unlock() + + alog := database.AuditLog(arg) + + q.auditLogs = append(q.auditLogs, alog) + slices.SortFunc(q.auditLogs, func(a, b database.AuditLog) bool { + return a.Time.Before(b.Time) + }) + + return alog, nil +} + +func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + q.derpMeshKey = id + return nil +} + +func (q *FakeQuerier) InsertDeploymentID(_ context.Context, id string) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + q.deploymentID = id + return nil +} + +func (q *FakeQuerier) InsertFile(_ context.Context, arg database.InsertFileParams) (database.File, error) { + if err := validateDatabaseType(arg); err != nil { + return database.File{}, err } - return cpy, nil + + q.mutex.Lock() + defer q.mutex.Unlock() + + //nolint:gosimple + file := database.File{ + ID: arg.ID, + Hash: arg.Hash, + CreatedAt: arg.CreatedAt, + CreatedBy: arg.CreatedBy, + Mimetype: arg.Mimetype, + Data: arg.Data, + } + q.files = append(q.files, file) + return file, nil } -func (q *FakeQuerier) GetWorkspaceProxyByHostname(_ context.Context, params database.GetWorkspaceProxyByHostnameParams) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { + if err := validateDatabaseType(arg); err != nil { + return database.GitAuthLink{}, err + } - // Return zero rows if this is called with a non-sanitized hostname. The SQL - // version of this query does the same thing. - if !validProxyByHostnameRegex.MatchString(params.Hostname) { - return database.WorkspaceProxy{}, sql.ErrNoRows + q.mutex.Lock() + defer q.mutex.Unlock() + // nolint:gosimple + gitAuthLink := database.GitAuthLink{ + ProviderID: arg.ProviderID, + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OAuthAccessToken: arg.OAuthAccessToken, + OAuthRefreshToken: arg.OAuthRefreshToken, + OAuthExpiry: arg.OAuthExpiry, } + q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink) + return gitAuthLink, nil +} - // This regex matches the SQL version. - accessURLRegex := regexp.MustCompile(`[^:]*://` + regexp.QuoteMeta(params.Hostname) + `([:/]?.)*`) +func (q *FakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { + if err := validateDatabaseType(arg); err != nil { + return database.GitSSHKey{}, err + } - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if params.AllowAccessUrl && accessURLRegex.MatchString(proxy.Url) { - return proxy, nil - } + q.mutex.Lock() + defer q.mutex.Unlock() - // Compile the app hostname regex. This is slow sadly. - if params.AllowWildcardHostname { - wildcardRegexp, err := httpapi.CompileHostnamePattern(proxy.WildcardHostname) - if err != nil { - return database.WorkspaceProxy{}, xerrors.Errorf("compile hostname pattern %q for proxy %q (%s): %w", proxy.WildcardHostname, proxy.Name, proxy.ID.String(), err) - } - if _, ok := httpapi.ExecuteHostnamePattern(wildcardRegexp, params.Hostname); ok { - return proxy, nil - } - } + //nolint:gosimple + gitSSHKey := database.GitSSHKey{ + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + PrivateKey: arg.PrivateKey, + PublicKey: arg.PublicKey, } - - return database.WorkspaceProxy{}, sql.ErrNoRows + q.gitSSHKey = append(q.gitSSHKey, gitSSHKey) + return gitSSHKey, nil } -func (q *FakeQuerier) GetWorkspaceProxyByID(_ context.Context, id uuid.UUID) (database.WorkspaceProxy, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - for _, proxy := range q.workspaceProxies { - if proxy.ID == id { - return proxy, nil - } +func (q *FakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Group{}, err } - return database.WorkspaceProxy{}, sql.ErrNoRows -} -func (q *FakeQuerier) GetWorkspaceProxyByName(_ context.Context, name string) (database.WorkspaceProxy, error) { q.mutex.Lock() defer q.mutex.Unlock() - for _, proxy := range q.workspaceProxies { - if proxy.Deleted { - continue - } - if proxy.Name == name { - return proxy, nil + for _, group := range q.groups { + if group.OrganizationID == arg.OrganizationID && + group.Name == arg.Name { + return database.Group{}, errDuplicateKey } } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) (database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - for _, resource := range q.workspaceResources { - if resource.ID == id { - return resource, nil - } + //nolint:gosimple + group := database.Group{ + ID: arg.ID, + Name: arg.Name, + OrganizationID: arg.OrganizationID, + AvatarURL: arg.AvatarURL, + QuotaAllowance: arg.QuotaAllowance, } - return database.WorkspaceResource{}, sql.ErrNoRows -} -func (q *FakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceResourceMetadatum, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + q.groups = append(q.groups, group) - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, metadatum := range q.workspaceResourceMetadata { - for _, id := range ids { - if metadatum.WorkspaceResourceID == id { - metadata = append(metadata, metadatum) - } - } - } - return metadata, nil + return group, nil } -func (q *FakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, after time.Time) ([]database.WorkspaceResourceMetadatum, error) { - resources, err := q.GetWorkspaceResourcesCreatedAfter(ctx, after) - if err != nil { - return nil, err - } - resourceIDs := map[uuid.UUID]struct{}{} - for _, resource := range resources { - resourceIDs[resource.ID] = struct{}{} +func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error { + if err := validateDatabaseType(arg); err != nil { + return err } - q.mutex.RLock() - defer q.mutex.RUnlock() + q.mutex.Lock() + defer q.mutex.Unlock() - metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, m := range q.workspaceResourceMetadata { - _, ok := resourceIDs[m.WorkspaceResourceID] - if !ok { - continue + for _, member := range q.groupMembers { + if member.GroupID == arg.GroupID && + member.UserID == arg.UserID { + return errDuplicateKey } - metadata = append(metadata, m) } - return metadata, nil -} -func (q *FakeQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() + //nolint:gosimple + q.groupMembers = append(q.groupMembers, database.GroupMember{ + GroupID: arg.GroupID, + UserID: arg.UserID, + }) - return q.getWorkspaceResourcesByJobIDNoLock(ctx, jobID) + return nil } -func (q *FakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs []uuid.UUID) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *FakeQuerier) InsertLicense( + _ context.Context, arg database.InsertLicenseParams, +) (database.License, error) { + if err := validateDatabaseType(arg); err != nil { + return database.License{}, err + } - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - for _, jobID := range jobIDs { - if resource.JobID != jobID { - continue - } - resources = append(resources, resource) - } + q.mutex.Lock() + defer q.mutex.Unlock() + + l := database.License{ + ID: q.lastLicenseID + 1, + UploadedAt: arg.UploadedAt, + JWT: arg.JWT, + Exp: arg.Exp, } - return resources, nil + q.lastLicenseID = l.ID + q.licenses = append(q.licenses, l) + return l, nil } -func (q *FakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after time.Time) ([]database.WorkspaceResource, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Organization{}, err + } - resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.workspaceResources { - if resource.CreatedAt.After(after) { - resources = append(resources, resource) - } + q.mutex.Lock() + defer q.mutex.Unlock() + + organization := database.Organization{ + ID: arg.ID, + Name: arg.Name, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, } - return resources, nil + q.organizations = append(q.organizations, organization) + return organization, nil } -func (q *FakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspacesParams) ([]database.GetWorkspacesRow, error) { +func (q *FakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { if err := validateDatabaseType(arg); err != nil { - return nil, err + return database.OrganizationMember{}, err } - // A nil auth filter means no auth filter. - workspaceRows, err := q.GetAuthorizedWorkspaces(ctx, arg, nil) - return workspaceRows, err + q.mutex.Lock() + defer q.mutex.Unlock() + + //nolint:gosimple + organizationMember := database.OrganizationMember{ + OrganizationID: arg.OrganizationID, + UserID: arg.UserID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Roles: arg.Roles, + } + q.organizationMembers = append(q.organizationMembers, organizationMember) + return organizationMember, nil } -func (q *FakeQuerier) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]database.Workspace, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() +func (q *FakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { + if err := validateDatabaseType(arg); err != nil { + return database.ProvisionerDaemon{}, err + } - workspaces := []database.Workspace{} - for _, workspace := range q.workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, err - } + q.mutex.Lock() + defer q.mutex.Unlock() - if build.Transition == database.WorkspaceTransitionStart && - !build.Deadline.IsZero() && - build.Deadline.Before(now) && - !workspace.LockedAt.Valid { - workspaces = append(workspaces, workspace) - continue - } + daemon := database.ProvisionerDaemon{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + Name: arg.Name, + Provisioners: arg.Provisioners, + Tags: arg.Tags, + } + q.provisionerDaemons = append(q.provisionerDaemons, daemon) + return daemon, nil +} - if build.Transition == database.WorkspaceTransitionStop && - workspace.AutostartSchedule.Valid && - !workspace.LockedAt.Valid { - workspaces = append(workspaces, workspace) - continue - } +func (q *FakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { + if err := validateDatabaseType(arg); err != nil { + return database.ProvisionerJob{}, err + } - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job by ID: %w", err) - } - if db2sdk.ProvisionerJobStatus(job) == codersdk.ProvisionerJobFailed { - workspaces = append(workspaces, workspace) - continue - } + q.mutex.Lock() + defer q.mutex.Unlock() - template, err := q.GetTemplateByID(ctx, workspace.TemplateID) - if err != nil { - return nil, xerrors.Errorf("get template by ID: %w", err) - } - if !workspace.LockedAt.Valid && template.InactivityTTL > 0 { - workspaces = append(workspaces, workspace) - continue - } - if workspace.LockedAt.Valid && template.LockedTTL > 0 { - workspaces = append(workspaces, workspace) - continue - } + job := database.ProvisionerJob{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OrganizationID: arg.OrganizationID, + InitiatorID: arg.InitiatorID, + Provisioner: arg.Provisioner, + StorageMethod: arg.StorageMethod, + FileID: arg.FileID, + Type: arg.Type, + Input: arg.Input, + Tags: arg.Tags, } - - return workspaces, nil + q.provisionerJobs = append(q.provisionerJobs, job) + return job, nil } -func (q *FakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { +func (q *FakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { if err := validateDatabaseType(arg); err != nil { - return database.APIKey{}, err + return nil, err } q.mutex.Lock() defer q.mutex.Unlock() - if arg.LifetimeSeconds == 0 { - arg.LifetimeSeconds = 86400 + logs := make([]database.ProvisionerJobLog, 0) + id := int64(1) + if len(q.provisionerJobLogs) > 0 { + id = q.provisionerJobLogs[len(q.provisionerJobLogs)-1].ID + } + for index, output := range arg.Output { + id++ + logs = append(logs, database.ProvisionerJobLog{ + ID: id, + JobID: arg.JobID, + CreatedAt: arg.CreatedAt[index], + Source: arg.Source[index], + Level: arg.Level[index], + Stage: arg.Stage[index], + Output: output, + }) } + q.provisionerJobLogs = append(q.provisionerJobLogs, logs...) + return logs, nil +} - for _, u := range q.users { - if u.ID == arg.UserID && u.Deleted { - return database.APIKey{}, xerrors.Errorf("refusing to create APIKey for deleted user") - } +func (q *FakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Replica{}, err } - //nolint:gosimple - key := database.APIKey{ + q.mutex.Lock() + defer q.mutex.Unlock() + + replica := database.Replica{ ID: arg.ID, - LifetimeSeconds: arg.LifetimeSeconds, - HashedSecret: arg.HashedSecret, - IPAddress: arg.IPAddress, - UserID: arg.UserID, - ExpiresAt: arg.ExpiresAt, CreatedAt: arg.CreatedAt, + StartedAt: arg.StartedAt, UpdatedAt: arg.UpdatedAt, - LastUsed: arg.LastUsed, - LoginType: arg.LoginType, - Scope: arg.Scope, - TokenName: arg.TokenName, + Hostname: arg.Hostname, + RegionID: arg.RegionID, + RelayAddress: arg.RelayAddress, + Version: arg.Version, + DatabaseLatency: arg.DatabaseLatency, } - q.apiKeys = append(q.apiKeys, key) - return key, nil + q.replicas = append(q.replicas, replica) + return replica, nil } -func (q *FakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) { - return q.InsertGroup(ctx, database.InsertGroupParams{ - ID: orgID, - Name: database.AllUsersGroup, - OrganizationID: orgID, - }) +func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) (database.Template, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Template{}, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + //nolint:gosimple + template := database.Template{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + OrganizationID: arg.OrganizationID, + Name: arg.Name, + Provisioner: arg.Provisioner, + ActiveVersionID: arg.ActiveVersionID, + Description: arg.Description, + CreatedBy: arg.CreatedBy, + UserACL: arg.UserACL, + GroupACL: arg.GroupACL, + DisplayName: arg.DisplayName, + Icon: arg.Icon, + AllowUserCancelWorkspaceJobs: arg.AllowUserCancelWorkspaceJobs, + AllowUserAutostart: true, + AllowUserAutostop: true, + } + q.templates = append(q.templates, template) + return template.DeepCopy(), nil } -func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) { +func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { if err := validateDatabaseType(arg); err != nil { - return database.AuditLog{}, err + return database.TemplateVersion{}, err + } + + if len(arg.Message) > 1048576 { + return database.TemplateVersion{}, xerrors.New("message too long") } q.mutex.Lock() defer q.mutex.Unlock() - alog := database.AuditLog(arg) + //nolint:gosimple + version := database.TemplateVersion{ + ID: arg.ID, + TemplateID: arg.TemplateID, + OrganizationID: arg.OrganizationID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Name: arg.Name, + Message: arg.Message, + Readme: arg.Readme, + JobID: arg.JobID, + CreatedBy: arg.CreatedBy, + } + q.templateVersions = append(q.templateVersions, version) + return version, nil +} - q.auditLogs = append(q.auditLogs, alog) - slices.SortFunc(q.auditLogs, func(a, b database.AuditLog) bool { - return a.Time.Before(b.Time) - }) +func (q *FakeQuerier) InsertTemplateVersionParameter(_ context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { + if err := validateDatabaseType(arg); err != nil { + return database.TemplateVersionParameter{}, err + } - return alog, nil + q.mutex.Lock() + defer q.mutex.Unlock() + + //nolint:gosimple + param := database.TemplateVersionParameter{ + TemplateVersionID: arg.TemplateVersionID, + Name: arg.Name, + DisplayName: arg.DisplayName, + Description: arg.Description, + Type: arg.Type, + Mutable: arg.Mutable, + DefaultValue: arg.DefaultValue, + Icon: arg.Icon, + Options: arg.Options, + ValidationError: arg.ValidationError, + ValidationRegex: arg.ValidationRegex, + ValidationMin: arg.ValidationMin, + ValidationMax: arg.ValidationMax, + ValidationMonotonic: arg.ValidationMonotonic, + Required: arg.Required, + DisplayOrder: arg.DisplayOrder, + Ephemeral: arg.Ephemeral, + } + q.templateVersionParameters = append(q.templateVersionParameters, param) + return param, nil } -func (q *FakeQuerier) InsertDERPMeshKey(_ context.Context, id string) error { +func (q *FakeQuerier) InsertTemplateVersionVariable(_ context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { + if err := validateDatabaseType(arg); err != nil { + return database.TemplateVersionVariable{}, err + } + q.mutex.Lock() defer q.mutex.Unlock() - q.derpMeshKey = id - return nil + //nolint:gosimple + variable := database.TemplateVersionVariable{ + TemplateVersionID: arg.TemplateVersionID, + Name: arg.Name, + Description: arg.Description, + Type: arg.Type, + Value: arg.Value, + DefaultValue: arg.DefaultValue, + Required: arg.Required, + Sensitive: arg.Sensitive, + } + q.templateVersionVariables = append(q.templateVersionVariables, variable) + return variable, nil } -func (q *FakeQuerier) InsertDeploymentID(_ context.Context, id string) error { +func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { + if err := validateDatabaseType(arg); err != nil { + return database.User{}, err + } + + // There is a common bug when using dbfake that 2 inserted users have the + // same created_at time. This causes user order to not be deterministic, + // which breaks some unit tests. + // To fix this, we make sure that the created_at time is always greater + // than the last user's created_at time. + allUsers, _ := q.GetUsers(context.Background(), database.GetUsersParams{}) + if len(allUsers) > 0 { + lastUser := allUsers[len(allUsers)-1] + if arg.CreatedAt.Before(lastUser.CreatedAt) || + arg.CreatedAt.Equal(lastUser.CreatedAt) { + // 1 ms is a good enough buffer. + arg.CreatedAt = lastUser.CreatedAt.Add(time.Millisecond) + } + } + q.mutex.Lock() defer q.mutex.Unlock() - q.deploymentID = id - return nil + for _, user := range q.users { + if user.Username == arg.Username && !user.Deleted { + return database.User{}, errDuplicateKey + } + } + + user := database.User{ + ID: arg.ID, + Email: arg.Email, + HashedPassword: arg.HashedPassword, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Username: arg.Username, + Status: database.UserStatusActive, + RBACRoles: arg.RBACRoles, + LoginType: arg.LoginType, + } + q.users = append(q.users, user) + return user, nil } -func (q *FakeQuerier) InsertFile(_ context.Context, arg database.InsertFileParams) (database.File, error) { - if err := validateDatabaseType(arg); err != nil { - return database.File{}, err +func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { + q.mutex.Lock() + defer q.mutex.Unlock() + + var groupIDs []uuid.UUID + for _, group := range q.groups { + for _, groupName := range arg.GroupNames { + if group.Name == groupName { + groupIDs = append(groupIDs, group.ID) + } + } + } + + for _, groupID := range groupIDs { + q.groupMembers = append(q.groupMembers, database.GroupMember{ + UserID: arg.UserID, + GroupID: groupID, + }) } + return nil +} + +func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { q.mutex.Lock() defer q.mutex.Unlock() //nolint:gosimple - file := database.File{ - ID: arg.ID, - Hash: arg.Hash, - CreatedAt: arg.CreatedAt, - CreatedBy: arg.CreatedBy, - Mimetype: arg.Mimetype, - Data: arg.Data, - } - q.files = append(q.files, file) - return file, nil + link := database.UserLink{ + UserID: args.UserID, + LoginType: args.LoginType, + LinkedID: args.LinkedID, + OAuthAccessToken: args.OAuthAccessToken, + OAuthRefreshToken: args.OAuthRefreshToken, + OAuthExpiry: args.OAuthExpiry, + } + + q.userLinks = append(q.userLinks, link) + + return link, nil } -func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGitAuthLinkParams) (database.GitAuthLink, error) { +func (q *FakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitAuthLink{}, err + return database.Workspace{}, err } q.mutex.Lock() defer q.mutex.Unlock() - // nolint:gosimple - gitAuthLink := database.GitAuthLink{ - ProviderID: arg.ProviderID, - UserID: arg.UserID, + + //nolint:gosimple + workspace := database.Workspace{ + ID: arg.ID, CreatedAt: arg.CreatedAt, UpdatedAt: arg.UpdatedAt, - OAuthAccessToken: arg.OAuthAccessToken, - OAuthRefreshToken: arg.OAuthRefreshToken, - OAuthExpiry: arg.OAuthExpiry, + OwnerID: arg.OwnerID, + OrganizationID: arg.OrganizationID, + TemplateID: arg.TemplateID, + Name: arg.Name, + AutostartSchedule: arg.AutostartSchedule, + Ttl: arg.Ttl, + LastUsedAt: arg.LastUsedAt, } - q.gitAuthLinks = append(q.gitAuthLinks, gitAuthLink) - return gitAuthLink, nil + q.workspaces = append(q.workspaces, workspace) + return workspace, nil } -func (q *FakeQuerier) InsertGitSSHKey(_ context.Context, arg database.InsertGitSSHKeyParams) (database.GitSSHKey, error) { +func (q *FakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err + return database.WorkspaceAgent{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - gitSSHKey := database.GitSSHKey{ - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - PrivateKey: arg.PrivateKey, - PublicKey: arg.PublicKey, + agent := database.WorkspaceAgent{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + ResourceID: arg.ResourceID, + AuthToken: arg.AuthToken, + AuthInstanceID: arg.AuthInstanceID, + EnvironmentVariables: arg.EnvironmentVariables, + Name: arg.Name, + Architecture: arg.Architecture, + OperatingSystem: arg.OperatingSystem, + Directory: arg.Directory, + StartupScriptBehavior: arg.StartupScriptBehavior, + StartupScript: arg.StartupScript, + InstanceMetadata: arg.InstanceMetadata, + ResourceMetadata: arg.ResourceMetadata, + ConnectionTimeoutSeconds: arg.ConnectionTimeoutSeconds, + TroubleshootingURL: arg.TroubleshootingURL, + MOTDFile: arg.MOTDFile, + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + ShutdownScript: arg.ShutdownScript, } - q.gitSSHKey = append(q.gitSSHKey, gitSSHKey) - return gitSSHKey, nil -} -func (q *FakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err - } + q.workspaceAgents = append(q.workspaceAgents, agent) + return agent, nil +} +func (q *FakeQuerier) InsertWorkspaceAgentMetadata(_ context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { q.mutex.Lock() defer q.mutex.Unlock() - for _, group := range q.groups { - if group.OrganizationID == arg.OrganizationID && - group.Name == arg.Name { - return database.Group{}, errDuplicateKey - } - } - //nolint:gosimple - group := database.Group{ - ID: arg.ID, - Name: arg.Name, - OrganizationID: arg.OrganizationID, - AvatarURL: arg.AvatarURL, - QuotaAllowance: arg.QuotaAllowance, + metadatum := database.WorkspaceAgentMetadatum{ + WorkspaceAgentID: arg.WorkspaceAgentID, + Script: arg.Script, + DisplayName: arg.DisplayName, + Key: arg.Key, + Timeout: arg.Timeout, + Interval: arg.Interval, } - q.groups = append(q.groups, group) - - return group, nil + q.workspaceAgentMetadata = append(q.workspaceAgentMetadata, metadatum) + return nil } -func (q *FakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error { +func (q *FakeQuerier) InsertWorkspaceAgentStartupLogs(_ context.Context, arg database.InsertWorkspaceAgentStartupLogsParams) ([]database.WorkspaceAgentStartupLog, error) { if err := validateDatabaseType(arg); err != nil { - return err + return nil, err } q.mutex.Lock() defer q.mutex.Unlock() - for _, member := range q.groupMembers { - if member.GroupID == arg.GroupID && - member.UserID == arg.UserID { - return errDuplicateKey + logs := []database.WorkspaceAgentStartupLog{} + id := int64(0) + if len(q.workspaceAgentLogs) > 0 { + id = q.workspaceAgentLogs[len(q.workspaceAgentLogs)-1].ID + } + outputLength := int32(0) + for index, output := range arg.Output { + id++ + logs = append(logs, database.WorkspaceAgentStartupLog{ + ID: id, + AgentID: arg.AgentID, + CreatedAt: arg.CreatedAt[index], + Level: arg.Level[index], + Output: output, + }) + outputLength += int32(len(output)) + } + for index, agent := range q.workspaceAgents { + if agent.ID != arg.AgentID { + continue + } + // Greater than 1MB, same as the PostgreSQL constraint! + if agent.StartupLogsLength+outputLength > (1 << 20) { + return nil, &pq.Error{ + Constraint: "max_startup_logs_length", + Table: "workspace_agents", + } } + agent.StartupLogsLength += outputLength + q.workspaceAgents[index] = agent + break } - - //nolint:gosimple - q.groupMembers = append(q.groupMembers, database.GroupMember{ - GroupID: arg.GroupID, - UserID: arg.UserID, - }) - - return nil + q.workspaceAgentLogs = append(q.workspaceAgentLogs, logs...) + return logs, nil } -func (q *FakeQuerier) InsertLicense( - _ context.Context, arg database.InsertLicenseParams, -) (database.License, error) { - if err := validateDatabaseType(arg); err != nil { - return database.License{}, err +func (q *FakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.InsertWorkspaceAgentStatParams) (database.WorkspaceAgentStat, error) { + if err := validateDatabaseType(p); err != nil { + return database.WorkspaceAgentStat{}, err } q.mutex.Lock() defer q.mutex.Unlock() - l := database.License{ - ID: q.lastLicenseID + 1, - UploadedAt: arg.UploadedAt, - JWT: arg.JWT, - Exp: arg.Exp, + stat := database.WorkspaceAgentStat{ + ID: p.ID, + CreatedAt: p.CreatedAt, + WorkspaceID: p.WorkspaceID, + AgentID: p.AgentID, + UserID: p.UserID, + ConnectionsByProto: p.ConnectionsByProto, + ConnectionCount: p.ConnectionCount, + RxPackets: p.RxPackets, + RxBytes: p.RxBytes, + TxPackets: p.TxPackets, + TxBytes: p.TxBytes, + TemplateID: p.TemplateID, + SessionCountVSCode: p.SessionCountVSCode, + SessionCountJetBrains: p.SessionCountJetBrains, + SessionCountReconnectingPTY: p.SessionCountReconnectingPTY, + SessionCountSSH: p.SessionCountSSH, + ConnectionMedianLatencyMS: p.ConnectionMedianLatencyMS, } - q.lastLicenseID = l.ID - q.licenses = append(q.licenses, l) - return l, nil + q.workspaceAgentStats = append(q.workspaceAgentStats, stat) + return stat, nil } -func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) { +func (q *FakeQuerier) InsertWorkspaceApp(_ context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { if err := validateDatabaseType(arg); err != nil { - return database.Organization{}, err + return database.WorkspaceApp{}, err } q.mutex.Lock() defer q.mutex.Unlock() - organization := database.Organization{ - ID: arg.ID, - Name: arg.Name, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, + if arg.SharingLevel == "" { + arg.SharingLevel = database.AppSharingLevelOwner } - q.organizations = append(q.organizations, organization) - return organization, nil + + // nolint:gosimple + workspaceApp := database.WorkspaceApp{ + ID: arg.ID, + AgentID: arg.AgentID, + CreatedAt: arg.CreatedAt, + Slug: arg.Slug, + DisplayName: arg.DisplayName, + Icon: arg.Icon, + Command: arg.Command, + Url: arg.Url, + External: arg.External, + Subdomain: arg.Subdomain, + SharingLevel: arg.SharingLevel, + HealthcheckUrl: arg.HealthcheckUrl, + HealthcheckInterval: arg.HealthcheckInterval, + HealthcheckThreshold: arg.HealthcheckThreshold, + Health: arg.Health, + } + q.workspaceApps = append(q.workspaceApps, workspaceApp) + return workspaceApp, nil } -func (q *FakeQuerier) InsertOrganizationMember(_ context.Context, arg database.InsertOrganizationMemberParams) (database.OrganizationMember, error) { +func (q *FakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err + return database.WorkspaceBuild{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - organizationMember := database.OrganizationMember{ - OrganizationID: arg.OrganizationID, - UserID: arg.UserID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Roles: arg.Roles, + workspaceBuild := database.WorkspaceBuild{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + WorkspaceID: arg.WorkspaceID, + TemplateVersionID: arg.TemplateVersionID, + BuildNumber: arg.BuildNumber, + Transition: arg.Transition, + InitiatorID: arg.InitiatorID, + JobID: arg.JobID, + ProvisionerState: arg.ProvisionerState, + Deadline: arg.Deadline, + Reason: arg.Reason, } - q.organizationMembers = append(q.organizationMembers, organizationMember) - return organizationMember, nil + q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild) + return workspaceBuild, nil } -func (q *FakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) { +func (q *FakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg database.InsertWorkspaceBuildParametersParams) error { if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerDaemon{}, err + return err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for index, name := range arg.Name { + q.workspaceBuildParameters = append(q.workspaceBuildParameters, database.WorkspaceBuildParameter{ + WorkspaceBuildID: arg.WorkspaceBuildID, + Name: name, + Value: arg.Value[index], + }) } + return nil +} +func (q *FakeQuerier) InsertWorkspaceProxy(_ context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { q.mutex.Lock() defer q.mutex.Unlock() - daemon := database.ProvisionerDaemon{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - Name: arg.Name, - Provisioners: arg.Provisioners, - Tags: arg.Tags, + for _, p := range q.workspaceProxies { + if !p.Deleted && p.Name == arg.Name { + return database.WorkspaceProxy{}, errDuplicateKey + } } - q.provisionerDaemons = append(q.provisionerDaemons, daemon) - return daemon, nil + + p := database.WorkspaceProxy{ + ID: arg.ID, + Name: arg.Name, + DisplayName: arg.DisplayName, + Icon: arg.Icon, + TokenHashedSecret: arg.TokenHashedSecret, + CreatedAt: arg.CreatedAt, + UpdatedAt: arg.UpdatedAt, + Deleted: false, + } + q.workspaceProxies = append(q.workspaceProxies, p) + return p, nil } -func (q *FakeQuerier) InsertProvisionerJob(_ context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) { +func (q *FakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { if err := validateDatabaseType(arg); err != nil { - return database.ProvisionerJob{}, err + return database.WorkspaceResource{}, err } q.mutex.Lock() defer q.mutex.Unlock() - job := database.ProvisionerJob{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - InitiatorID: arg.InitiatorID, - Provisioner: arg.Provisioner, - StorageMethod: arg.StorageMethod, - FileID: arg.FileID, - Type: arg.Type, - Input: arg.Input, - Tags: arg.Tags, + //nolint:gosimple + resource := database.WorkspaceResource{ + ID: arg.ID, + CreatedAt: arg.CreatedAt, + JobID: arg.JobID, + Transition: arg.Transition, + Type: arg.Type, + Name: arg.Name, + Hide: arg.Hide, + Icon: arg.Icon, + DailyCost: arg.DailyCost, } - q.provisionerJobs = append(q.provisionerJobs, job) - return job, nil + q.workspaceResources = append(q.workspaceResources, resource) + return resource, nil } -func (q *FakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) { +func (q *FakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { if err := validateDatabaseType(arg); err != nil { return nil, err } @@ -3864,465 +3931,388 @@ func (q *FakeQuerier) InsertProvisionerJobLogs(_ context.Context, arg database.I q.mutex.Lock() defer q.mutex.Unlock() - logs := make([]database.ProvisionerJobLog, 0) + metadata := make([]database.WorkspaceResourceMetadatum, 0) id := int64(1) - if len(q.provisionerJobLogs) > 0 { - id = q.provisionerJobLogs[len(q.provisionerJobLogs)-1].ID + if len(q.workspaceResourceMetadata) > 0 { + id = q.workspaceResourceMetadata[len(q.workspaceResourceMetadata)-1].ID } - for index, output := range arg.Output { + for index, key := range arg.Key { id++ - logs = append(logs, database.ProvisionerJobLog{ - ID: id, - JobID: arg.JobID, - CreatedAt: arg.CreatedAt[index], - Source: arg.Source[index], - Level: arg.Level[index], - Stage: arg.Stage[index], - Output: output, + value := arg.Value[index] + metadata = append(metadata, database.WorkspaceResourceMetadatum{ + ID: id, + WorkspaceResourceID: arg.WorkspaceResourceID, + Key: key, + Value: sql.NullString{ + String: value, + Valid: value != "", + }, + Sensitive: arg.Sensitive[index], }) } - q.provisionerJobLogs = append(q.provisionerJobLogs, logs...) - return logs, nil + q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadata...) + return metadata, nil } -func (q *FakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err - } - +func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { q.mutex.Lock() defer q.mutex.Unlock() - replica := database.Replica{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - StartedAt: arg.StartedAt, - UpdatedAt: arg.UpdatedAt, - Hostname: arg.Hostname, - RegionID: arg.RegionID, - RelayAddress: arg.RelayAddress, - Version: arg.Version, - DatabaseLatency: arg.DatabaseLatency, + for i, p := range q.workspaceProxies { + if p.ID == arg.ID { + p.Url = arg.Url + p.WildcardHostname = arg.WildcardHostname + p.UpdatedAt = database.Now() + q.workspaceProxies[i] = p + return p, nil + } } - q.replicas = append(q.replicas, replica) - return replica, nil + return database.WorkspaceProxy{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTemplateParams) (database.Template, error) { +func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { + return false, xerrors.New("TryAcquireLock must only be called within a transaction") +} + +func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - template := database.Template{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OrganizationID: arg.OrganizationID, - Name: arg.Name, - Provisioner: arg.Provisioner, - ActiveVersionID: arg.ActiveVersionID, - Description: arg.Description, - CreatedBy: arg.CreatedBy, - UserACL: arg.UserACL, - GroupACL: arg.GroupACL, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - AllowUserCancelWorkspaceJobs: arg.AllowUserCancelWorkspaceJobs, - AllowUserAutostart: true, - AllowUserAutostop: true, + for index, apiKey := range q.apiKeys { + if apiKey.ID != arg.ID { + continue + } + apiKey.LastUsed = arg.LastUsed + apiKey.ExpiresAt = arg.ExpiresAt + apiKey.IPAddress = arg.IPAddress + q.apiKeys[index] = apiKey + return nil } - q.templates = append(q.templates, template) - return template.DeepCopy(), nil + return sql.ErrNoRows } -func (q *FakeQuerier) InsertTemplateVersion(_ context.Context, arg database.InsertTemplateVersionParams) (database.TemplateVersion, error) { +func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err - } - - if len(arg.Message) > 1048576 { - return database.TemplateVersion{}, xerrors.New("message too long") + return database.GitAuthLink{}, err } q.mutex.Lock() defer q.mutex.Unlock() + for index, gitAuthLink := range q.gitAuthLinks { + if gitAuthLink.ProviderID != arg.ProviderID { + continue + } + if gitAuthLink.UserID != arg.UserID { + continue + } + gitAuthLink.UpdatedAt = arg.UpdatedAt + gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken + gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken + gitAuthLink.OAuthExpiry = arg.OAuthExpiry + q.gitAuthLinks[index] = gitAuthLink - //nolint:gosimple - version := database.TemplateVersion{ - ID: arg.ID, - TemplateID: arg.TemplateID, - OrganizationID: arg.OrganizationID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Name: arg.Name, - Message: arg.Message, - Readme: arg.Readme, - JobID: arg.JobID, - CreatedBy: arg.CreatedBy, + return gitAuthLink, nil } - q.templateVersions = append(q.templateVersions, version) - return version, nil + return database.GitAuthLink{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertTemplateVersionParameter(_ context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) { +func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionParameter{}, err + return database.GitSSHKey{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - param := database.TemplateVersionParameter{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Description: arg.Description, - Type: arg.Type, - Mutable: arg.Mutable, - DefaultValue: arg.DefaultValue, - Icon: arg.Icon, - Options: arg.Options, - ValidationError: arg.ValidationError, - ValidationRegex: arg.ValidationRegex, - ValidationMin: arg.ValidationMin, - ValidationMax: arg.ValidationMax, - ValidationMonotonic: arg.ValidationMonotonic, - Required: arg.Required, - DisplayOrder: arg.DisplayOrder, - Ephemeral: arg.Ephemeral, + for index, key := range q.gitSSHKey { + if key.UserID != arg.UserID { + continue + } + key.UpdatedAt = arg.UpdatedAt + key.PrivateKey = arg.PrivateKey + key.PublicKey = arg.PublicKey + q.gitSSHKey[index] = key + return key, nil } - q.templateVersionParameters = append(q.templateVersionParameters, param) - return param, nil + return database.GitSSHKey{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertTemplateVersionVariable(_ context.Context, arg database.InsertTemplateVersionVariableParams) (database.TemplateVersionVariable, error) { +func (q *FakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersionVariable{}, err + return database.Group{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - variable := database.TemplateVersionVariable{ - TemplateVersionID: arg.TemplateVersionID, - Name: arg.Name, - Description: arg.Description, - Type: arg.Type, - Value: arg.Value, - DefaultValue: arg.DefaultValue, - Required: arg.Required, - Sensitive: arg.Sensitive, - } - q.templateVersionVariables = append(q.templateVersionVariables, variable) - return variable, nil + for i, group := range q.groups { + if group.ID == arg.ID { + group.Name = arg.Name + group.AvatarURL = arg.AvatarURL + group.QuotaAllowance = arg.QuotaAllowance + q.groups[i] = group + return group, nil + } + } + return database.Group{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { +func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { if err := validateDatabaseType(arg); err != nil { - return database.User{}, err + return database.OrganizationMember{}, err } - // There is a common bug when using dbfake that 2 inserted users have the - // same created_at time. This causes user order to not be deterministic, - // which breaks some unit tests. - // To fix this, we make sure that the created_at time is always greater - // than the last user's created_at time. - allUsers, _ := q.GetUsers(context.Background(), database.GetUsersParams{}) - if len(allUsers) > 0 { - lastUser := allUsers[len(allUsers)-1] - if arg.CreatedAt.Before(lastUser.CreatedAt) || - arg.CreatedAt.Equal(lastUser.CreatedAt) { - // 1 ms is a good enough buffer. - arg.CreatedAt = lastUser.CreatedAt.Add(time.Millisecond) + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, mem := range q.organizationMembers { + if mem.UserID == arg.UserID && mem.OrganizationID == arg.OrgID { + uniqueRoles := make([]string, 0, len(arg.GrantedRoles)) + exist := make(map[string]struct{}) + for _, r := range arg.GrantedRoles { + if _, ok := exist[r]; ok { + continue + } + exist[r] = struct{}{} + uniqueRoles = append(uniqueRoles, r) + } + sort.Strings(uniqueRoles) + + mem.Roles = uniqueRoles + q.organizationMembers[i] = mem + return mem, nil } } + return database.OrganizationMember{}, sql.ErrNoRows +} + +func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() - for _, user := range q.users { - if user.Username == arg.Username && !user.Deleted { - return database.User{}, errDuplicateKey + for index, job := range q.provisionerJobs { + if arg.ID != job.ID { + continue } + job.UpdatedAt = arg.UpdatedAt + q.provisionerJobs[index] = job + return nil } + return sql.ErrNoRows +} - user := database.User{ - ID: arg.ID, - Email: arg.Email, - HashedPassword: arg.HashedPassword, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Username: arg.Username, - Status: database.UserStatusActive, - RBACRoles: arg.RBACRoles, - LoginType: arg.LoginType, +func (q *FakeQuerier) UpdateProvisionerJobWithCancelByID(_ context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { + if err := validateDatabaseType(arg); err != nil { + return err } - q.users = append(q.users, user) - return user, nil -} -func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { q.mutex.Lock() defer q.mutex.Unlock() - var groupIDs []uuid.UUID - for _, group := range q.groups { - for _, groupName := range arg.GroupNames { - if group.Name == groupName { - groupIDs = append(groupIDs, group.ID) - } + for index, job := range q.provisionerJobs { + if arg.ID != job.ID { + continue } + job.CanceledAt = arg.CanceledAt + job.CompletedAt = arg.CompletedAt + q.provisionerJobs[index] = job + return nil } + return sql.ErrNoRows +} - for _, groupID := range groupIDs { - q.groupMembers = append(q.groupMembers, database.GroupMember{ - UserID: arg.UserID, - GroupID: groupID, - }) +func (q *FakeQuerier) UpdateProvisionerJobWithCompleteByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { + if err := validateDatabaseType(arg); err != nil { + return err } - return nil -} - -func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) { q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - link := database.UserLink{ - UserID: args.UserID, - LoginType: args.LoginType, - LinkedID: args.LinkedID, - OAuthAccessToken: args.OAuthAccessToken, - OAuthRefreshToken: args.OAuthRefreshToken, - OAuthExpiry: args.OAuthExpiry, + for index, job := range q.provisionerJobs { + if arg.ID != job.ID { + continue + } + job.UpdatedAt = arg.UpdatedAt + job.CompletedAt = arg.CompletedAt + job.Error = arg.Error + job.ErrorCode = arg.ErrorCode + q.provisionerJobs[index] = job + return nil } - - q.userLinks = append(q.userLinks, link) - - return link, nil + return sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspace(_ context.Context, arg database.InsertWorkspaceParams) (database.Workspace, error) { +func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err + return database.Replica{}, err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - workspace := database.Workspace{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - OwnerID: arg.OwnerID, - OrganizationID: arg.OrganizationID, - TemplateID: arg.TemplateID, - Name: arg.Name, - AutostartSchedule: arg.AutostartSchedule, - Ttl: arg.Ttl, - LastUsedAt: arg.LastUsedAt, + for index, replica := range q.replicas { + if replica.ID != arg.ID { + continue + } + replica.Hostname = arg.Hostname + replica.StartedAt = arg.StartedAt + replica.StoppedAt = arg.StoppedAt + replica.UpdatedAt = arg.UpdatedAt + replica.RelayAddress = arg.RelayAddress + replica.RegionID = arg.RegionID + replica.Version = arg.Version + replica.Error = arg.Error + replica.DatabaseLatency = arg.DatabaseLatency + q.replicas[index] = replica + return replica, nil } - q.workspaces = append(q.workspaces, workspace) - return workspace, nil + return database.Replica{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) { +func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceAgent{}, err + return database.Template{}, err } q.mutex.Lock() defer q.mutex.Unlock() - agent := database.WorkspaceAgent{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - ResourceID: arg.ResourceID, - AuthToken: arg.AuthToken, - AuthInstanceID: arg.AuthInstanceID, - EnvironmentVariables: arg.EnvironmentVariables, - Name: arg.Name, - Architecture: arg.Architecture, - OperatingSystem: arg.OperatingSystem, - Directory: arg.Directory, - StartupScriptBehavior: arg.StartupScriptBehavior, - StartupScript: arg.StartupScript, - InstanceMetadata: arg.InstanceMetadata, - ResourceMetadata: arg.ResourceMetadata, - ConnectionTimeoutSeconds: arg.ConnectionTimeoutSeconds, - TroubleshootingURL: arg.TroubleshootingURL, - MOTDFile: arg.MOTDFile, - LifecycleState: database.WorkspaceAgentLifecycleStateCreated, - ShutdownScript: arg.ShutdownScript, + for i, template := range q.templates { + if template.ID == arg.ID { + template.GroupACL = arg.GroupACL + template.UserACL = arg.UserACL + + q.templates[i] = template + return template.DeepCopy(), nil + } } - q.workspaceAgents = append(q.workspaceAgents, agent) - return agent, nil + return database.Template{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceAgentMetadata(_ context.Context, arg database.InsertWorkspaceAgentMetadataParams) error { +func (q *FakeQuerier) UpdateTemplateActiveVersionByID(_ context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - metadatum := database.WorkspaceAgentMetadatum{ - WorkspaceAgentID: arg.WorkspaceAgentID, - Script: arg.Script, - DisplayName: arg.DisplayName, - Key: arg.Key, - Timeout: arg.Timeout, - Interval: arg.Interval, + for index, template := range q.templates { + if template.ID != arg.ID { + continue + } + template.ActiveVersionID = arg.ActiveVersionID + template.UpdatedAt = arg.UpdatedAt + q.templates[index] = template + return nil } - - q.workspaceAgentMetadata = append(q.workspaceAgentMetadata, metadatum) - return nil + return sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceAgentStartupLogs(_ context.Context, arg database.InsertWorkspaceAgentStartupLogsParams) ([]database.WorkspaceAgentStartupLog, error) { +func (q *FakeQuerier) UpdateTemplateDeletedByID(_ context.Context, arg database.UpdateTemplateDeletedByIDParams) error { if err := validateDatabaseType(arg); err != nil { - return nil, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - logs := []database.WorkspaceAgentStartupLog{} - id := int64(0) - if len(q.workspaceAgentLogs) > 0 { - id = q.workspaceAgentLogs[len(q.workspaceAgentLogs)-1].ID - } - outputLength := int32(0) - for index, output := range arg.Output { - id++ - logs = append(logs, database.WorkspaceAgentStartupLog{ - ID: id, - AgentID: arg.AgentID, - CreatedAt: arg.CreatedAt[index], - Level: arg.Level[index], - Output: output, - }) - outputLength += int32(len(output)) - } - for index, agent := range q.workspaceAgents { - if agent.ID != arg.AgentID { + for index, template := range q.templates { + if template.ID != arg.ID { continue } - // Greater than 1MB, same as the PostgreSQL constraint! - if agent.StartupLogsLength+outputLength > (1 << 20) { - return nil, &pq.Error{ - Constraint: "max_startup_logs_length", - Table: "workspace_agents", - } - } - agent.StartupLogsLength += outputLength - q.workspaceAgents[index] = agent - break + template.Deleted = arg.Deleted + template.UpdatedAt = arg.UpdatedAt + q.templates[index] = template + return nil } - q.workspaceAgentLogs = append(q.workspaceAgentLogs, logs...) - return logs, nil + return sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceAgentStat(_ context.Context, p database.InsertWorkspaceAgentStatParams) (database.WorkspaceAgentStat, error) { - if err := validateDatabaseType(p); err != nil { - return database.WorkspaceAgentStat{}, err +func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { + if err := validateDatabaseType(arg); err != nil { + return database.Template{}, err } q.mutex.Lock() defer q.mutex.Unlock() - stat := database.WorkspaceAgentStat{ - ID: p.ID, - CreatedAt: p.CreatedAt, - WorkspaceID: p.WorkspaceID, - AgentID: p.AgentID, - UserID: p.UserID, - ConnectionsByProto: p.ConnectionsByProto, - ConnectionCount: p.ConnectionCount, - RxPackets: p.RxPackets, - RxBytes: p.RxBytes, - TxPackets: p.TxPackets, - TxBytes: p.TxBytes, - TemplateID: p.TemplateID, - SessionCountVSCode: p.SessionCountVSCode, - SessionCountJetBrains: p.SessionCountJetBrains, - SessionCountReconnectingPTY: p.SessionCountReconnectingPTY, - SessionCountSSH: p.SessionCountSSH, - ConnectionMedianLatencyMS: p.ConnectionMedianLatencyMS, + for idx, tpl := range q.templates { + if tpl.ID != arg.ID { + continue + } + tpl.UpdatedAt = database.Now() + tpl.Name = arg.Name + tpl.DisplayName = arg.DisplayName + tpl.Description = arg.Description + tpl.Icon = arg.Icon + q.templates[idx] = tpl + return tpl.DeepCopy(), nil } - q.workspaceAgentStats = append(q.workspaceAgentStats, stat) - return stat, nil + + return database.Template{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceApp(_ context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) { +func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) { if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceApp{}, err + return database.Template{}, err } q.mutex.Lock() defer q.mutex.Unlock() - if arg.SharingLevel == "" { - arg.SharingLevel = database.AppSharingLevelOwner + for idx, tpl := range q.templates { + if tpl.ID != arg.ID { + continue + } + tpl.AllowUserAutostart = arg.AllowUserAutostart + tpl.AllowUserAutostop = arg.AllowUserAutostop + tpl.UpdatedAt = database.Now() + tpl.DefaultTTL = arg.DefaultTTL + tpl.MaxTTL = arg.MaxTTL + tpl.FailureTTL = arg.FailureTTL + tpl.InactivityTTL = arg.InactivityTTL + tpl.LockedTTL = arg.LockedTTL + q.templates[idx] = tpl + return tpl.DeepCopy(), nil } - // nolint:gosimple - workspaceApp := database.WorkspaceApp{ - ID: arg.ID, - AgentID: arg.AgentID, - CreatedAt: arg.CreatedAt, - Slug: arg.Slug, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - Command: arg.Command, - Url: arg.Url, - External: arg.External, - Subdomain: arg.Subdomain, - SharingLevel: arg.SharingLevel, - HealthcheckUrl: arg.HealthcheckUrl, - HealthcheckInterval: arg.HealthcheckInterval, - HealthcheckThreshold: arg.HealthcheckThreshold, - Health: arg.Health, - } - q.workspaceApps = append(q.workspaceApps, workspaceApp) - return workspaceApp, nil + return database.Template{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.InsertWorkspaceBuildParams) (database.WorkspaceBuild, error) { +func (q *FakeQuerier) UpdateTemplateVersionByID(_ context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) { if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err + return database.TemplateVersion{}, err } q.mutex.Lock() defer q.mutex.Unlock() - workspaceBuild := database.WorkspaceBuild{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - WorkspaceID: arg.WorkspaceID, - TemplateVersionID: arg.TemplateVersionID, - BuildNumber: arg.BuildNumber, - Transition: arg.Transition, - InitiatorID: arg.InitiatorID, - JobID: arg.JobID, - ProvisionerState: arg.ProvisionerState, - Deadline: arg.Deadline, - Reason: arg.Reason, + for index, templateVersion := range q.templateVersions { + if templateVersion.ID != arg.ID { + continue + } + templateVersion.TemplateID = arg.TemplateID + templateVersion.UpdatedAt = arg.UpdatedAt + templateVersion.Name = arg.Name + q.templateVersions[index] = templateVersion + return templateVersion, nil } - q.workspaceBuilds = append(q.workspaceBuilds, workspaceBuild) - return workspaceBuild, nil + return database.TemplateVersion{}, sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg database.InsertWorkspaceBuildParametersParams) error { +func (q *FakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4330,233 +4320,279 @@ func (q *FakeQuerier) InsertWorkspaceBuildParameters(_ context.Context, arg data q.mutex.Lock() defer q.mutex.Unlock() - for index, name := range arg.Name { - q.workspaceBuildParameters = append(q.workspaceBuildParameters, database.WorkspaceBuildParameter{ - WorkspaceBuildID: arg.WorkspaceBuildID, - Name: name, - Value: arg.Value[index], - }) + for index, templateVersion := range q.templateVersions { + if templateVersion.JobID != arg.JobID { + continue + } + templateVersion.Readme = arg.Readme + templateVersion.UpdatedAt = arg.UpdatedAt + q.templateVersions[index] = templateVersion + return nil } - return nil + return sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceProxy(_ context.Context, arg database.InsertWorkspaceProxyParams) (database.WorkspaceProxy, error) { +func (q *FakeQuerier) UpdateTemplateVersionGitAuthProvidersByJobID(_ context.Context, arg database.UpdateTemplateVersionGitAuthProvidersByJobIDParams) error { + if err := validateDatabaseType(arg); err != nil { + return err + } + q.mutex.Lock() defer q.mutex.Unlock() - for _, p := range q.workspaceProxies { - if !p.Deleted && p.Name == arg.Name { - return database.WorkspaceProxy{}, errDuplicateKey + for index, templateVersion := range q.templateVersions { + if templateVersion.JobID != arg.JobID { + continue } + templateVersion.GitAuthProviders = arg.GitAuthProviders + templateVersion.UpdatedAt = arg.UpdatedAt + q.templateVersions[index] = templateVersion + return nil } + return sql.ErrNoRows +} - p := database.WorkspaceProxy{ - ID: arg.ID, - Name: arg.Name, - DisplayName: arg.DisplayName, - Icon: arg.Icon, - TokenHashedSecret: arg.TokenHashedSecret, - CreatedAt: arg.CreatedAt, - UpdatedAt: arg.UpdatedAt, - Deleted: false, +func (q *FakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.UpdateUserDeletedByIDParams) error { + if err := validateDatabaseType(params); err != nil { + return err } - q.workspaceProxies = append(q.workspaceProxies, p) - return p, nil + + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, u := range q.users { + if u.ID == params.ID { + u.Deleted = params.Deleted + q.users[i] = u + // NOTE: In the real world, this is done by a trigger. + i := 0 + for { + if i >= len(q.apiKeys) { + break + } + k := q.apiKeys[i] + if k.UserID == u.ID { + q.apiKeys[i] = q.apiKeys[len(q.apiKeys)-1] + q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] + // We removed an element, so decrement + i-- + } + i++ + } + return nil + } + } + return sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) { +func (q *FakeQuerier) UpdateUserHashedPassword(_ context.Context, arg database.UpdateUserHashedPasswordParams) error { if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceResource{}, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - //nolint:gosimple - resource := database.WorkspaceResource{ - ID: arg.ID, - CreatedAt: arg.CreatedAt, - JobID: arg.JobID, - Transition: arg.Transition, - Type: arg.Type, - Name: arg.Name, - Hide: arg.Hide, - Icon: arg.Icon, - DailyCost: arg.DailyCost, + for i, user := range q.users { + if user.ID != arg.ID { + continue + } + user.HashedPassword = arg.HashedPassword + q.users[i] = user + return nil } - q.workspaceResources = append(q.workspaceResources, resource) - return resource, nil + return sql.ErrNoRows } -func (q *FakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) { +func (q *FakeQuerier) UpdateUserLastSeenAt(_ context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { if err := validateDatabaseType(arg); err != nil { - return nil, err + return database.User{}, err } q.mutex.Lock() defer q.mutex.Unlock() - metadata := make([]database.WorkspaceResourceMetadatum, 0) - id := int64(1) - if len(q.workspaceResourceMetadata) > 0 { - id = q.workspaceResourceMetadata[len(q.workspaceResourceMetadata)-1].ID + for index, user := range q.users { + if user.ID != arg.ID { + continue + } + user.LastSeenAt = arg.LastSeenAt + user.UpdatedAt = arg.UpdatedAt + q.users[index] = user + return user, nil } - for index, key := range arg.Key { - id++ - value := arg.Value[index] - metadata = append(metadata, database.WorkspaceResourceMetadatum{ - ID: id, - WorkspaceResourceID: arg.WorkspaceResourceID, - Key: key, - Value: sql.NullString{ - String: value, - Valid: value != "", - }, - Sensitive: arg.Sensitive[index], - }) + return database.User{}, sql.ErrNoRows +} + +func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { + if err := validateDatabaseType(params); err != nil { + return database.UserLink{}, err + } + + q.mutex.Lock() + defer q.mutex.Unlock() + + for i, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + link.OAuthAccessToken = params.OAuthAccessToken + link.OAuthRefreshToken = params.OAuthRefreshToken + link.OAuthExpiry = params.OAuthExpiry + + q.userLinks[i] = link + return link, nil + } } - q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadata...) - return metadata, nil + + return database.UserLink{}, sql.ErrNoRows } -func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.RegisterWorkspaceProxyParams) (database.WorkspaceProxy, error) { +func (q *FakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) { + if err := validateDatabaseType(params); err != nil { + return database.UserLink{}, err + } + q.mutex.Lock() defer q.mutex.Unlock() - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Url = arg.Url - p.WildcardHostname = arg.WildcardHostname - p.UpdatedAt = database.Now() - q.workspaceProxies[i] = p - return p, nil + for i, link := range q.userLinks { + if link.UserID == params.UserID && link.LoginType == params.LoginType { + link.LinkedID = params.LinkedID + + q.userLinks[i] = link + return link, nil } } - return database.WorkspaceProxy{}, sql.ErrNoRows -} -func (*FakeQuerier) TryAcquireLock(_ context.Context, _ int64) (bool, error) { - return false, xerrors.New("TryAcquireLock must only be called within a transaction") + return database.UserLink{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { +func (q *FakeQuerier) UpdateUserLoginType(_ context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) { if err := validateDatabaseType(arg); err != nil { - return err + return database.User{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for index, apiKey := range q.apiKeys { - if apiKey.ID != arg.ID { - continue + for i, u := range q.users { + if u.ID == arg.UserID { + u.LoginType = arg.NewLoginType + if arg.NewLoginType != database.LoginTypePassword { + u.HashedPassword = []byte{} + } + q.users[i] = u + return u, nil } - apiKey.LastUsed = arg.LastUsed - apiKey.ExpiresAt = arg.ExpiresAt - apiKey.IPAddress = arg.IPAddress - q.apiKeys[index] = apiKey - return nil } - return sql.ErrNoRows + return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGitAuthLinkParams) (database.GitAuthLink, error) { +func (q *FakeQuerier) UpdateUserProfile(_ context.Context, arg database.UpdateUserProfileParams) (database.User, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitAuthLink{}, err + return database.User{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for index, gitAuthLink := range q.gitAuthLinks { - if gitAuthLink.ProviderID != arg.ProviderID { - continue - } - if gitAuthLink.UserID != arg.UserID { + + for index, user := range q.users { + if user.ID != arg.ID { continue } - gitAuthLink.UpdatedAt = arg.UpdatedAt - gitAuthLink.OAuthAccessToken = arg.OAuthAccessToken - gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken - gitAuthLink.OAuthExpiry = arg.OAuthExpiry - q.gitAuthLinks[index] = gitAuthLink - - return gitAuthLink, nil + user.Email = arg.Email + user.Username = arg.Username + user.AvatarURL = arg.AvatarURL + q.users[index] = user + return user, nil } - return database.GitAuthLink{}, sql.ErrNoRows + return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitSSHKeyParams) (database.GitSSHKey, error) { +func (q *FakeQuerier) UpdateUserRoles(_ context.Context, arg database.UpdateUserRolesParams) (database.User, error) { if err := validateDatabaseType(arg); err != nil { - return database.GitSSHKey{}, err + return database.User{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for index, key := range q.gitSSHKey { - if key.UserID != arg.UserID { + for index, user := range q.users { + if user.ID != arg.ID { continue } - key.UpdatedAt = arg.UpdatedAt - key.PrivateKey = arg.PrivateKey - key.PublicKey = arg.PublicKey - q.gitSSHKey[index] = key - return key, nil + + // Set new roles + user.RBACRoles = arg.GrantedRoles + // Remove duplicates and sort + uniqueRoles := make([]string, 0, len(user.RBACRoles)) + exist := make(map[string]struct{}) + for _, r := range user.RBACRoles { + if _, ok := exist[r]; ok { + continue + } + exist[r] = struct{}{} + uniqueRoles = append(uniqueRoles, r) + } + sort.Strings(uniqueRoles) + user.RBACRoles = uniqueRoles + + q.users[index] = user + return user, nil } - return database.GitSSHKey{}, sql.ErrNoRows + return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) { +func (q *FakeQuerier) UpdateUserStatus(_ context.Context, arg database.UpdateUserStatusParams) (database.User, error) { if err := validateDatabaseType(arg); err != nil { - return database.Group{}, err + return database.User{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for i, group := range q.groups { - if group.ID == arg.ID { - group.Name = arg.Name - group.AvatarURL = arg.AvatarURL - group.QuotaAllowance = arg.QuotaAllowance - q.groups[i] = group - return group, nil + for index, user := range q.users { + if user.ID != arg.ID { + continue } + user.Status = arg.Status + user.UpdatedAt = arg.UpdatedAt + q.users[index] = user + return user, nil } - return database.Group{}, sql.ErrNoRows + return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateMemberRoles(_ context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) { +func (q *FakeQuerier) UpdateWorkspace(_ context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { if err := validateDatabaseType(arg); err != nil { - return database.OrganizationMember{}, err + return database.Workspace{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for i, mem := range q.organizationMembers { - if mem.UserID == arg.UserID && mem.OrganizationID == arg.OrgID { - uniqueRoles := make([]string, 0, len(arg.GrantedRoles)) - exist := make(map[string]struct{}) - for _, r := range arg.GrantedRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) + for i, workspace := range q.workspaces { + if workspace.Deleted || workspace.ID != arg.ID { + continue + } + for _, other := range q.workspaces { + if other.Deleted || other.ID == workspace.ID || workspace.OwnerID != other.OwnerID { + continue + } + if other.Name == arg.Name { + return database.Workspace{}, errDuplicateKey } - sort.Strings(uniqueRoles) - - mem.Roles = uniqueRoles - q.organizationMembers[i] = mem - return mem, nil } + + workspace.Name = arg.Name + q.workspaces[i] = workspace + + return workspace, nil } - return database.OrganizationMember{}, sql.ErrNoRows + return database.Workspace{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error { +func (q *FakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4564,108 +4600,102 @@ func (q *FakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.U q.mutex.Lock() defer q.mutex.Unlock() - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { + for index, agent := range q.workspaceAgents { + if agent.ID != arg.ID { continue } - job.UpdatedAt = arg.UpdatedAt - q.provisionerJobs[index] = job + agent.FirstConnectedAt = arg.FirstConnectedAt + agent.LastConnectedAt = arg.LastConnectedAt + agent.DisconnectedAt = arg.DisconnectedAt + agent.UpdatedAt = arg.UpdatedAt + q.workspaceAgents[index] = agent return nil } return sql.ErrNoRows } -func (q *FakeQuerier) UpdateProvisionerJobWithCancelByID(_ context.Context, arg database.UpdateProvisionerJobWithCancelByIDParams) error { +func (q *FakeQuerier) UpdateWorkspaceAgentLifecycleStateByID(_ context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { if err := validateDatabaseType(arg); err != nil { return err } q.mutex.Lock() defer q.mutex.Unlock() - - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue + for i, agent := range q.workspaceAgents { + if agent.ID == arg.ID { + agent.LifecycleState = arg.LifecycleState + agent.StartedAt = arg.StartedAt + agent.ReadyAt = arg.ReadyAt + q.workspaceAgents[i] = agent + return nil } - job.CanceledAt = arg.CanceledAt - job.CompletedAt = arg.CompletedAt - q.provisionerJobs[index] = job - return nil } return sql.ErrNoRows } -func (q *FakeQuerier) UpdateProvisionerJobWithCompleteByID(_ context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - +func (q *FakeQuerier) UpdateWorkspaceAgentMetadata(_ context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error { q.mutex.Lock() defer q.mutex.Unlock() - for index, job := range q.provisionerJobs { - if arg.ID != job.ID { - continue + //nolint:gosimple + updated := database.WorkspaceAgentMetadatum{ + WorkspaceAgentID: arg.WorkspaceAgentID, + Key: arg.Key, + Value: arg.Value, + Error: arg.Error, + CollectedAt: arg.CollectedAt, + } + + for i, m := range q.workspaceAgentMetadata { + if m.WorkspaceAgentID == arg.WorkspaceAgentID && m.Key == arg.Key { + q.workspaceAgentMetadata[i] = updated + return nil } - job.UpdatedAt = arg.UpdatedAt - job.CompletedAt = arg.CompletedAt - job.Error = arg.Error - job.ErrorCode = arg.ErrorCode - q.provisionerJobs[index] = job - return nil } - return sql.ErrNoRows + + return nil } -func (q *FakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) { +func (q *FakeQuerier) UpdateWorkspaceAgentStartupByID(_ context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { if err := validateDatabaseType(arg); err != nil { - return database.Replica{}, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - for index, replica := range q.replicas { - if replica.ID != arg.ID { + for index, agent := range q.workspaceAgents { + if agent.ID != arg.ID { continue } - replica.Hostname = arg.Hostname - replica.StartedAt = arg.StartedAt - replica.StoppedAt = arg.StoppedAt - replica.UpdatedAt = arg.UpdatedAt - replica.RelayAddress = arg.RelayAddress - replica.RegionID = arg.RegionID - replica.Version = arg.Version - replica.Error = arg.Error - replica.DatabaseLatency = arg.DatabaseLatency - q.replicas[index] = replica - return replica, nil + + agent.Version = arg.Version + agent.ExpandedDirectory = arg.ExpandedDirectory + agent.Subsystem = arg.Subsystem + q.workspaceAgents[index] = agent + return nil } - return database.Replica{}, sql.ErrNoRows + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateACLByID(_ context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) { +func (q *FakeQuerier) UpdateWorkspaceAgentStartupLogOverflowByID(_ context.Context, arg database.UpdateWorkspaceAgentStartupLogOverflowByIDParams) error { if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - - for i, template := range q.templates { - if template.ID == arg.ID { - template.GroupACL = arg.GroupACL - template.UserACL = arg.UserACL - - q.templates[i] = template - return template.DeepCopy(), nil + for i, agent := range q.workspaceAgents { + if agent.ID == arg.ID { + agent.StartupLogsOverflowed = arg.StartupLogsOverflowed + q.workspaceAgents[i] = agent + return nil } } - - return database.Template{}, sql.ErrNoRows + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateActiveVersionByID(_ context.Context, arg database.UpdateTemplateActiveVersionByIDParams) error { +func (q *FakeQuerier) UpdateWorkspaceAppHealthByID(_ context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4673,19 +4703,18 @@ func (q *FakeQuerier) UpdateTemplateActiveVersionByID(_ context.Context, arg dat q.mutex.Lock() defer q.mutex.Unlock() - for index, template := range q.templates { - if template.ID != arg.ID { + for index, app := range q.workspaceApps { + if app.ID != arg.ID { continue } - template.ActiveVersionID = arg.ActiveVersionID - template.UpdatedAt = arg.UpdatedAt - q.templates[index] = template + app.Health = arg.Health + q.workspaceApps[index] = app return nil } return sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateDeletedByID(_ context.Context, arg database.UpdateTemplateDeletedByIDParams) error { +func (q *FakeQuerier) UpdateWorkspaceAutostart(_ context.Context, arg database.UpdateWorkspaceAutostartParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4693,91 +4722,79 @@ func (q *FakeQuerier) UpdateTemplateDeletedByID(_ context.Context, arg database. q.mutex.Lock() defer q.mutex.Unlock() - for index, template := range q.templates { - if template.ID != arg.ID { + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { continue } - template.Deleted = arg.Deleted - template.UpdatedAt = arg.UpdatedAt - q.templates[index] = template + workspace.AutostartSchedule = arg.AutostartSchedule + q.workspaces[index] = workspace return nil } + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.UpdateTemplateMetaByIDParams) (database.Template, error) { +func (q *FakeQuerier) UpdateWorkspaceBuildByID(_ context.Context, arg database.UpdateWorkspaceBuildByIDParams) (database.WorkspaceBuild, error) { if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err + return database.WorkspaceBuild{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { + for index, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.ID != arg.ID { continue } - tpl.UpdatedAt = database.Now() - tpl.Name = arg.Name - tpl.DisplayName = arg.DisplayName - tpl.Description = arg.Description - tpl.Icon = arg.Icon - q.templates[idx] = tpl - return tpl.DeepCopy(), nil + workspaceBuild.UpdatedAt = arg.UpdatedAt + workspaceBuild.ProvisionerState = arg.ProvisionerState + workspaceBuild.Deadline = arg.Deadline + workspaceBuild.MaxDeadline = arg.MaxDeadline + q.workspaceBuilds[index] = workspaceBuild + return workspaceBuild, nil } - - return database.Template{}, sql.ErrNoRows + return database.WorkspaceBuild{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateScheduleByID(_ context.Context, arg database.UpdateTemplateScheduleByIDParams) (database.Template, error) { +func (q *FakeQuerier) UpdateWorkspaceBuildCostByID(_ context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) { if err := validateDatabaseType(arg); err != nil { - return database.Template{}, err + return database.WorkspaceBuild{}, err } q.mutex.Lock() defer q.mutex.Unlock() - for idx, tpl := range q.templates { - if tpl.ID != arg.ID { + for index, workspaceBuild := range q.workspaceBuilds { + if workspaceBuild.ID != arg.ID { continue } - tpl.AllowUserAutostart = arg.AllowUserAutostart - tpl.AllowUserAutostop = arg.AllowUserAutostop - tpl.UpdatedAt = database.Now() - tpl.DefaultTTL = arg.DefaultTTL - tpl.MaxTTL = arg.MaxTTL - tpl.FailureTTL = arg.FailureTTL - tpl.InactivityTTL = arg.InactivityTTL - tpl.LockedTTL = arg.LockedTTL - q.templates[idx] = tpl - return tpl.DeepCopy(), nil + workspaceBuild.DailyCost = arg.DailyCost + q.workspaceBuilds[index] = workspaceBuild + return workspaceBuild, nil } - - return database.Template{}, sql.ErrNoRows + return database.WorkspaceBuild{}, sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateVersionByID(_ context.Context, arg database.UpdateTemplateVersionByIDParams) (database.TemplateVersion, error) { +func (q *FakeQuerier) UpdateWorkspaceDeletedByID(_ context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { if err := validateDatabaseType(arg); err != nil { - return database.TemplateVersion{}, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - for index, templateVersion := range q.templateVersions { - if templateVersion.ID != arg.ID { + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { continue } - templateVersion.TemplateID = arg.TemplateID - templateVersion.UpdatedAt = arg.UpdatedAt - templateVersion.Name = arg.Name - q.templateVersions[index] = templateVersion - return templateVersion, nil + workspace.Deleted = arg.Deleted + q.workspaces[index] = workspace + return nil } - return database.TemplateVersion{}, sql.ErrNoRows + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context, arg database.UpdateTemplateVersionDescriptionByJobIDParams) error { +func (q *FakeQuerier) UpdateWorkspaceLastUsedAt(_ context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4785,19 +4802,19 @@ func (q *FakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context, q.mutex.Lock() defer q.mutex.Unlock() - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { continue } - templateVersion.Readme = arg.Readme - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion + workspace.LastUsedAt = arg.LastUsedAt + q.workspaces[index] = workspace return nil } + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateTemplateVersionGitAuthProvidersByJobID(_ context.Context, arg database.UpdateTemplateVersionGitAuthProvidersByJobIDParams) error { +func (q *FakeQuerier) UpdateWorkspaceLockedAt(_ context.Context, arg database.UpdateWorkspaceLockedAtParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4805,52 +4822,60 @@ func (q *FakeQuerier) UpdateTemplateVersionGitAuthProvidersByJobID(_ context.Con q.mutex.Lock() defer q.mutex.Unlock() - for index, templateVersion := range q.templateVersions { - if templateVersion.JobID != arg.JobID { + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { continue } - templateVersion.GitAuthProviders = arg.GitAuthProviders - templateVersion.UpdatedAt = arg.UpdatedAt - q.templateVersions[index] = templateVersion + workspace.LockedAt = arg.LockedAt + workspace.LastUsedAt = database.Now() + q.workspaces[index] = workspace return nil } + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.UpdateUserDeletedByIDParams) error { - if err := validateDatabaseType(params); err != nil { - return err +func (q *FakeQuerier) UpdateWorkspaceProxy(_ context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, p := range q.workspaceProxies { + if p.Name == arg.Name && p.ID != arg.ID { + return database.WorkspaceProxy{}, errDuplicateKey + } + } + + for i, p := range q.workspaceProxies { + if p.ID == arg.ID { + p.Name = arg.Name + p.DisplayName = arg.DisplayName + p.Icon = arg.Icon + if len(p.TokenHashedSecret) > 0 { + p.TokenHashedSecret = arg.TokenHashedSecret + } + q.workspaceProxies[i] = p + return p, nil + } } + return database.WorkspaceProxy{}, sql.ErrNoRows +} +func (q *FakeQuerier) UpdateWorkspaceProxyDeleted(_ context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { q.mutex.Lock() defer q.mutex.Unlock() - for i, u := range q.users { - if u.ID == params.ID { - u.Deleted = params.Deleted - q.users[i] = u - // NOTE: In the real world, this is done by a trigger. - i := 0 - for { - if i >= len(q.apiKeys) { - break - } - k := q.apiKeys[i] - if k.UserID == u.ID { - q.apiKeys[i] = q.apiKeys[len(q.apiKeys)-1] - q.apiKeys = q.apiKeys[:len(q.apiKeys)-1] - // We removed an element, so decrement - i-- - } - i++ - } + for i, p := range q.workspaceProxies { + if p.ID == arg.ID { + p.Deleted = arg.Deleted + p.UpdatedAt = database.Now() + q.workspaceProxies[i] = p return nil } } return sql.ErrNoRows } -func (q *FakeQuerier) UpdateUserHashedPassword(_ context.Context, arg database.UpdateUserHashedPasswordParams) error { +func (q *FakeQuerier) UpdateWorkspaceTTL(_ context.Context, arg database.UpdateWorkspaceTTLParams) error { if err := validateDatabaseType(arg); err != nil { return err } @@ -4858,582 +4883,557 @@ func (q *FakeQuerier) UpdateUserHashedPassword(_ context.Context, arg database.U q.mutex.Lock() defer q.mutex.Unlock() - for i, user := range q.users { - if user.ID != arg.ID { + for index, workspace := range q.workspaces { + if workspace.ID != arg.ID { continue } - user.HashedPassword = arg.HashedPassword - q.users[i] = user + workspace.Ttl = arg.Ttl + q.workspaces[index] = workspace return nil } + return sql.ErrNoRows } -func (q *FakeQuerier) UpdateUserLastSeenAt(_ context.Context, arg database.UpdateUserLastSeenAtParams) (database.User, error) { +func (q *FakeQuerier) UpdateWorkspaceTTLToBeWithinTemplateMax(_ context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) error { if err := validateDatabaseType(arg); err != nil { - return database.User{}, err + return err } q.mutex.Lock() defer q.mutex.Unlock() - for index, user := range q.users { - if user.ID != arg.ID { + for index, workspace := range q.workspaces { + if workspace.TemplateID != arg.TemplateID || !workspace.Ttl.Valid || workspace.Ttl.Int64 < arg.TemplateMaxTTL { continue } - user.LastSeenAt = arg.LastSeenAt - user.UpdatedAt = arg.UpdatedAt - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - - q.mutex.Lock() - defer q.mutex.Unlock() - - for i, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - link.OAuthAccessToken = params.OAuthAccessToken - link.OAuthRefreshToken = params.OAuthRefreshToken - link.OAuthExpiry = params.OAuthExpiry - q.userLinks[i] = link - return link, nil - } + workspace.Ttl = sql.NullInt64{Int64: arg.TemplateMaxTTL, Valid: true} + q.workspaces[index] = workspace } - return database.UserLink{}, sql.ErrNoRows + return nil } -func (q *FakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) { - if err := validateDatabaseType(params); err != nil { - return database.UserLink{}, err - } - +func (q *FakeQuerier) UpsertAppSecurityKey(_ context.Context, data string) error { q.mutex.Lock() defer q.mutex.Unlock() - for i, link := range q.userLinks { - if link.UserID == params.UserID && link.LoginType == params.LoginType { - link.LinkedID = params.LinkedID - - q.userLinks[i] = link - return link, nil - } - } - - return database.UserLink{}, sql.ErrNoRows + q.appSecurityKey = data + return nil } -func (q *FakeQuerier) UpdateUserLoginType(_ context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } +func (q *FakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error { + q.defaultProxyDisplayName = arg.DisplayName + q.defaultProxyIconURL = arg.IconUrl + return nil +} +func (q *FakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { q.mutex.Lock() defer q.mutex.Unlock() - for i, u := range q.users { - if u.ID == arg.UserID { - u.LoginType = arg.NewLoginType - if arg.NewLoginType != database.LoginTypePassword { - u.HashedPassword = []byte{} - } - q.users[i] = u - return u, nil - } - } - return database.User{}, sql.ErrNoRows + q.lastUpdateCheck = []byte(data) + return nil } -func (q *FakeQuerier) UpdateUserProfile(_ context.Context, arg database.UpdateUserProfileParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } +func (q *FakeQuerier) UpsertLogoURL(_ context.Context, data string) error { + q.mutex.RLock() + defer q.mutex.RUnlock() + + q.logoURL = data + return nil +} +func (q *FakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) error { q.mutex.Lock() defer q.mutex.Unlock() - for index, user := range q.users { - if user.ID != arg.ID { - continue - } - user.Email = arg.Email - user.Username = arg.Username - user.AvatarURL = arg.AvatarURL - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows + q.oauthSigningKey = value + return nil } -func (q *FakeQuerier) UpdateUserRoles(_ context.Context, arg database.UpdateUserRolesParams) (database.User, error) { - if err := validateDatabaseType(arg); err != nil { - return database.User{}, err - } +func (q *FakeQuerier) UpsertServiceBanner(_ context.Context, data string) error { + q.mutex.RLock() + defer q.mutex.RUnlock() - q.mutex.Lock() - defer q.mutex.Unlock() + q.serviceBanner = []byte(data) + return nil +} - for index, user := range q.users { - if user.ID != arg.ID { - continue - } +func (*FakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + return database.TailnetAgent{}, ErrUnimplemented +} - // Set new roles - user.RBACRoles = arg.GrantedRoles - // Remove duplicates and sort - uniqueRoles := make([]string, 0, len(user.RBACRoles)) - exist := make(map[string]struct{}) - for _, r := range user.RBACRoles { - if _, ok := exist[r]; ok { - continue - } - exist[r] = struct{}{} - uniqueRoles = append(uniqueRoles, r) - } - sort.Strings(uniqueRoles) - user.RBACRoles = uniqueRoles +func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { + return database.TailnetClient{}, ErrUnimplemented +} - q.users[index] = user - return user, nil - } - return database.User{}, sql.ErrNoRows +func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { + return database.TailnetCoordinator{}, ErrUnimplemented } -func (q *FakeQuerier) UpdateUserStatus(_ context.Context, arg database.UpdateUserStatusParams) (database.User, error) { +func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { if err := validateDatabaseType(arg); err != nil { - return database.User{}, err + return nil, err } - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() - for index, user := range q.users { - if user.ID != arg.ID { - continue + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) + if err != nil { + return nil, err } - user.Status = arg.Status - user.UpdatedAt = arg.UpdatedAt - q.users[index] = user - return user, nil } - return database.User{}, sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspace(_ context.Context, arg database.UpdateWorkspaceParams) (database.Workspace, error) { - if err := validateDatabaseType(arg); err != nil { - return database.Workspace{}, err - } + var templates []database.Template + for _, template := range q.templates { + if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { + continue + } - q.mutex.Lock() - defer q.mutex.Unlock() + if template.Deleted != arg.Deleted { + continue + } + if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { + continue + } - for i, workspace := range q.workspaces { - if workspace.Deleted || workspace.ID != arg.ID { + if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { continue } - for _, other := range q.workspaces { - if other.Deleted || other.ID == workspace.ID || workspace.OwnerID != other.OwnerID { - continue + + if len(arg.IDs) > 0 { + match := false + for _, id := range arg.IDs { + if template.ID == id { + match = true + break + } } - if other.Name == arg.Name { - return database.Workspace{}, errDuplicateKey + if !match { + continue } } - - workspace.Name = arg.Name - q.workspaces[i] = workspace - - return workspace, nil + templates = append(templates, template.DeepCopy()) + } + if len(templates) > 0 { + slices.SortFunc(templates, func(i, j database.Template) bool { + if i.Name != j.Name { + return i.Name < j.Name + } + return i.ID.String() < j.ID.String() + }) + return templates, nil } - return database.Workspace{}, sql.ErrNoRows + return nil, sql.ErrNoRows } -func (q *FakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err +func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var template database.Template + for _, t := range q.templates { + if t.ID == id { + template = t + break + } } - q.mutex.Lock() - defer q.mutex.Unlock() + if template.ID == uuid.Nil { + return nil, sql.ErrNoRows + } - for index, agent := range q.workspaceAgents { - if agent.ID != arg.ID { + groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) + for k, v := range template.GroupACL { + group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get group by ID: %w", err) + } + // We don't delete groups from the map if they + // get deleted so just skip. + if xerrors.Is(err, sql.ErrNoRows) { continue } - agent.FirstConnectedAt = arg.FirstConnectedAt - agent.LastConnectedAt = arg.LastConnectedAt - agent.DisconnectedAt = arg.DisconnectedAt - agent.UpdatedAt = arg.UpdatedAt - q.workspaceAgents[index] = agent - return nil + + groups = append(groups, database.TemplateGroup{ + Group: group, + Actions: v, + }) } - return sql.ErrNoRows + + return groups, nil } -func (q *FakeQuerier) UpdateWorkspaceAgentLifecycleStateByID(_ context.Context, arg database.UpdateWorkspaceAgentLifecycleStateByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } +func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() - q.mutex.Lock() - defer q.mutex.Unlock() - for i, agent := range q.workspaceAgents { - if agent.ID == arg.ID { - agent.LifecycleState = arg.LifecycleState - agent.StartedAt = arg.StartedAt - agent.ReadyAt = arg.ReadyAt - q.workspaceAgents[i] = agent - return nil + var template database.Template + for _, t := range q.templates { + if t.ID == id { + template = t + break } } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceAgentMetadata(_ context.Context, arg database.UpdateWorkspaceAgentMetadataParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - //nolint:gosimple - updated := database.WorkspaceAgentMetadatum{ - WorkspaceAgentID: arg.WorkspaceAgentID, - Key: arg.Key, - Value: arg.Value, - Error: arg.Error, - CollectedAt: arg.CollectedAt, + if template.ID == uuid.Nil { + return nil, sql.ErrNoRows } - for i, m := range q.workspaceAgentMetadata { - if m.WorkspaceAgentID == arg.WorkspaceAgentID && m.Key == arg.Key { - q.workspaceAgentMetadata[i] = updated - return nil + users := make([]database.TemplateUser, 0, len(template.UserACL)) + for k, v := range template.UserACL { + user, err := q.getUserByIDNoLock(uuid.MustParse(k)) + if err != nil && xerrors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get user by ID: %w", err) + } + // We don't delete users from the map if they + // get deleted so just skip. + if xerrors.Is(err, sql.ErrNoRows) { + continue + } + + if user.Deleted || user.Status == database.UserStatusSuspended { + continue } + + users = append(users, database.TemplateUser{ + User: user, + Actions: v, + }) } - return nil + return users, nil } -func (q *FakeQuerier) UpdateWorkspaceAgentStartupByID(_ context.Context, arg database.UpdateWorkspaceAgentStartupByIDParams) error { +//nolint:gocyclo +func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { if err := validateDatabaseType(arg); err != nil { - return err + return nil, err } - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() - for index, agent := range q.workspaceAgents { - if agent.ID != arg.ID { - continue + if prepared != nil { + // Call this to match the same function calls as the SQL implementation. + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return nil, err } - - agent.Version = arg.Version - agent.ExpandedDirectory = arg.ExpandedDirectory - agent.Subsystem = arg.Subsystem - q.workspaceAgents[index] = agent - return nil } - return sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspaceAgentStartupLogOverflowByID(_ context.Context, arg database.UpdateWorkspaceAgentStartupLogOverflowByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } + workspaces := make([]database.Workspace, 0) + for _, workspace := range q.workspaces { + if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { + continue + } - q.mutex.Lock() - defer q.mutex.Unlock() - for i, agent := range q.workspaceAgents { - if agent.ID == arg.ID { - agent.StartupLogsOverflowed = arg.StartupLogsOverflowed - q.workspaceAgents[i] = agent - return nil + if arg.OwnerUsername != "" { + owner, err := q.getUserByIDNoLock(workspace.OwnerID) + if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { + continue + } } - } - return sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspaceAppHealthByID(_ context.Context, arg database.UpdateWorkspaceAppHealthByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } + if arg.TemplateName != "" { + template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) + if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { + continue + } + } - q.mutex.Lock() - defer q.mutex.Unlock() + if !arg.Deleted && workspace.Deleted { + continue + } - for index, app := range q.workspaceApps { - if app.ID != arg.ID { + if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { continue } - app.Health = arg.Health - q.workspaceApps[index] = app - return nil - } - return sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspaceAutostart(_ context.Context, arg database.UpdateWorkspaceAutostartParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } + if arg.Status != "" { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, xerrors.Errorf("get latest build: %w", err) + } - q.mutex.Lock() - defer q.mutex.Unlock() + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue - } - workspace.AutostartSchedule = arg.AutostartSchedule - q.workspaces[index] = workspace - return nil - } + // This logic should match the logic in the workspace.sql file. + var statusMatch bool + switch database.WorkspaceStatus(arg.Status) { + case database.WorkspaceStatusPending: + statusMatch = isNull(job.StartedAt) + case database.WorkspaceStatusStarting: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionStart - return sql.ErrNoRows -} + case database.WorkspaceStatusRunning: + statusMatch = isNotNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionStart -func (q *FakeQuerier) UpdateWorkspaceBuildByID(_ context.Context, arg database.UpdateWorkspaceBuildByIDParams) (database.WorkspaceBuild, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err - } + case database.WorkspaceStatusStopping: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionStop - q.mutex.Lock() - defer q.mutex.Unlock() + case database.WorkspaceStatusStopped: + statusMatch = isNotNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionStop + case database.WorkspaceStatusFailed: + statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) || + (isNotNull(job.CompletedAt) && isNotNull(job.Error)) - for index, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != arg.ID { - continue - } - workspaceBuild.UpdatedAt = arg.UpdatedAt - workspaceBuild.ProvisionerState = arg.ProvisionerState - workspaceBuild.Deadline = arg.Deadline - workspaceBuild.MaxDeadline = arg.MaxDeadline - q.workspaceBuilds[index] = workspaceBuild - return workspaceBuild, nil - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} + case database.WorkspaceStatusCanceling: + statusMatch = isNotNull(job.CanceledAt) && + isNull(job.CompletedAt) -func (q *FakeQuerier) UpdateWorkspaceBuildCostByID(_ context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) { - if err := validateDatabaseType(arg); err != nil { - return database.WorkspaceBuild{}, err - } + case database.WorkspaceStatusCanceled: + statusMatch = isNotNull(job.CanceledAt) && + isNotNull(job.CompletedAt) - q.mutex.Lock() - defer q.mutex.Unlock() + case database.WorkspaceStatusDeleted: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNotNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionDelete && + isNull(job.Error) - for index, workspaceBuild := range q.workspaceBuilds { - if workspaceBuild.ID != arg.ID { - continue + case database.WorkspaceStatusDeleting: + statusMatch = isNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionDelete + + default: + return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status) + } + if !statusMatch { + continue + } } - workspaceBuild.DailyCost = arg.DailyCost - q.workspaceBuilds[index] = workspaceBuild - return workspaceBuild, nil - } - return database.WorkspaceBuild{}, sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspaceDeletedByID(_ context.Context, arg database.UpdateWorkspaceDeletedByIDParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } + if arg.HasAgent != "" { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, xerrors.Errorf("get latest build: %w", err) + } - q.mutex.Lock() - defer q.mutex.Unlock() + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { + workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) + if err != nil { + return nil, xerrors.Errorf("get workspace resources: %w", err) + } + + var workspaceResourceIDs []uuid.UUID + for _, wr := range workspaceResources { + workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) + } + + workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) + if err != nil { + return nil, xerrors.Errorf("get workspace agents: %w", err) + } + + var hasAgentMatched bool + for _, wa := range workspaceAgents { + if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { + hasAgentMatched = true + } + } + + if !hasAgentMatched { + continue + } + } + + if len(arg.TemplateIds) > 0 { + match := false + for _, id := range arg.TemplateIds { + if workspace.TemplateID == id { + match = true + break + } + } + if !match { + continue + } + } + + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { continue } - workspace.Deleted = arg.Deleted - q.workspaces[index] = workspace - return nil + workspaces = append(workspaces, workspace) } - return sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspaceLastUsedAt(_ context.Context, arg database.UpdateWorkspaceLastUsedAtParams) error { - if err := validateDatabaseType(arg); err != nil { - return err + // Sort workspaces (ORDER BY) + isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { + return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart } - q.mutex.Lock() - defer q.mutex.Unlock() + preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} + preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} + preloadedUsers := map[uuid.UUID]database.User{} - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue + for _, w := range workspaces { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) + if err == nil { + preloadedWorkspaceBuilds[w.ID] = build + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get latest build: %w", err) } - workspace.LastUsedAt = arg.LastUsedAt - q.workspaces[index] = workspace - return nil - } - return sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceLockedAt(_ context.Context, arg database.UpdateWorkspaceLockedAtParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } - - q.mutex.Lock() - defer q.mutex.Unlock() + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err == nil { + preloadedProvisionerJobs[w.ID] = job + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue + user, err := q.getUserByIDNoLock(w.OwnerID) + if err == nil { + preloadedUsers[w.ID] = user + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get user: %w", err) } - workspace.LockedAt = arg.LockedAt - workspace.LastUsedAt = database.Now() - q.workspaces[index] = workspace - return nil } - return sql.ErrNoRows -} + sort.Slice(workspaces, func(i, j int) bool { + w1 := workspaces[i] + w2 := workspaces[j] -func (q *FakeQuerier) UpdateWorkspaceProxy(_ context.Context, arg database.UpdateWorkspaceProxyParams) (database.WorkspaceProxy, error) { - q.mutex.Lock() - defer q.mutex.Unlock() + // Order by: running first + w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) + w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) - for _, p := range q.workspaceProxies { - if p.Name == arg.Name && p.ID != arg.ID { - return database.WorkspaceProxy{}, errDuplicateKey + if w1IsRunning && !w2IsRunning { + return true } - } - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Name = arg.Name - p.DisplayName = arg.DisplayName - p.Icon = arg.Icon - if len(p.TokenHashedSecret) > 0 { - p.TokenHashedSecret = arg.TokenHashedSecret - } - q.workspaceProxies[i] = p - return p, nil + if !w1IsRunning && w2IsRunning { + return false } - } - return database.WorkspaceProxy{}, sql.ErrNoRows -} - -func (q *FakeQuerier) UpdateWorkspaceProxyDeleted(_ context.Context, arg database.UpdateWorkspaceProxyDeletedParams) error { - q.mutex.Lock() - defer q.mutex.Unlock() - for i, p := range q.workspaceProxies { - if p.ID == arg.ID { - p.Deleted = arg.Deleted - p.UpdatedAt = database.Now() - q.workspaceProxies[i] = p - return nil + // Order by: usernames + if w1.ID != w2.ID { + return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username}) } - } - return sql.ErrNoRows -} -func (q *FakeQuerier) UpdateWorkspaceTTL(_ context.Context, arg database.UpdateWorkspaceTTLParams) error { - if err := validateDatabaseType(arg); err != nil { - return err - } + // Order by: workspace names + return sort.StringsAreSorted([]string{w1.Name, w2.Name}) + }) - q.mutex.Lock() - defer q.mutex.Unlock() + beforePageCount := len(workspaces) - for index, workspace := range q.workspaces { - if workspace.ID != arg.ID { - continue + if arg.Offset > 0 { + if int(arg.Offset) > len(workspaces) { + return []database.GetWorkspacesRow{}, nil } - workspace.Ttl = arg.Ttl - q.workspaces[index] = workspace - return nil + workspaces = workspaces[arg.Offset:] + } + if arg.Limit > 0 { + if int(arg.Limit) > len(workspaces) { + return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil + } + workspaces = workspaces[:arg.Limit] } - return sql.ErrNoRows + return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil } -func (q *FakeQuerier) UpdateWorkspaceTTLToBeWithinTemplateMax(_ context.Context, arg database.UpdateWorkspaceTTLToBeWithinTemplateMaxParams) error { - if err := validateDatabaseType(arg); err != nil { - return err +func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + if err := validateDatabaseType(params); err != nil { + return 0, err } - q.mutex.Lock() - defer q.mutex.Unlock() + q.mutex.RLock() + defer q.mutex.RUnlock() - for index, workspace := range q.workspaces { - if workspace.TemplateID != arg.TemplateID || !workspace.Ttl.Valid || workspace.Ttl.Int64 < arg.TemplateMaxTTL { - continue + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return -1, err } - - workspace.Ttl = sql.NullInt64{Int64: arg.TemplateMaxTTL, Valid: true} - q.workspaces[index] = workspace } - return nil -} - -func (q *FakeQuerier) UpsertAppSecurityKey(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.appSecurityKey = data - return nil -} - -func (q *FakeQuerier) UpsertDefaultProxy(_ context.Context, arg database.UpsertDefaultProxyParams) error { - q.defaultProxyDisplayName = arg.DisplayName - q.defaultProxyIconURL = arg.IconUrl - return nil -} - -func (q *FakeQuerier) UpsertLastUpdateCheck(_ context.Context, data string) error { - q.mutex.Lock() - defer q.mutex.Unlock() - - q.lastUpdateCheck = []byte(data) - return nil -} - -func (q *FakeQuerier) UpsertLogoURL(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() + users := make([]database.User, 0, len(q.users)) - q.logoURL = data - return nil -} + for _, user := range q.users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } -func (q *FakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) error { - q.mutex.Lock() - defer q.mutex.Unlock() + users = append(users, user) + } - q.oauthSigningKey = value - return nil -} + // Filter out deleted since they should never be returned.. + tmp := make([]database.User, 0, len(users)) + for _, user := range users { + if !user.Deleted { + tmp = append(tmp, user) + } + } + users = tmp -func (q *FakeQuerier) UpsertServiceBanner(_ context.Context, data string) error { - q.mutex.RLock() - defer q.mutex.RUnlock() + if params.Search != "" { + tmp := make([]database.User, 0, len(users)) + for i, user := range users { + if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } + } + users = tmp + } - q.serviceBanner = []byte(data) - return nil -} + if len(params.Status) > 0 { + usersFilteredByStatus := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { + return strings.EqualFold(string(a), string(b)) + }) { + usersFilteredByStatus = append(usersFilteredByStatus, users[i]) + } + } + users = usersFilteredByStatus + } -func (*FakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { - return database.TailnetAgent{}, ErrUnimplemented -} + if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { + usersFilteredByRole := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { + usersFilteredByRole = append(usersFilteredByRole, users[i]) + } + } -func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { - return database.TailnetClient{}, ErrUnimplemented -} + users = usersFilteredByRole + } -func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { - return database.TailnetCoordinator{}, ErrUnimplemented + return int64(len(users)), nil } diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index ec28fd428a102..11d857a30262b 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -16,6 +16,12 @@ import ( "github.com/coder/coder/coderd/rbac" ) +var ( + // Force these imports, for some reason the autogen does not include them. + _ uuid.UUID + _ rbac.Action +) + const wrapname = "dbmetrics.metricsStore" // New returns a database.Store that registers metrics for all queries to reg. @@ -73,41 +79,6 @@ func (m metricsStore) InTx(f func(database.Store) error, options *sql.TxOptions) return err } -func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - start := time.Now() - templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds()) - return templates, err -} - -func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - start := time.Now() - roles, err := m.s.GetTemplateGroupRoles(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds()) - return roles, err -} - -func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - start := time.Now() - roles, err := m.s.GetTemplateUserRoles(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds()) - return roles, err -} - -func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - start := time.Now() - workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds()) - return workspaces, err -} - -func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - start := time.Now() - count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) - return count, err -} - func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error { start := time.Now() err := m.s.AcquireLock(ctx, pgAdvisoryXactLock) @@ -1639,3 +1610,38 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) return m.s.UpsertTailnetCoordinator(ctx, id) } + +func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { + start := time.Now() + templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds()) + return templates, err +} + +func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + start := time.Now() + roles, err := m.s.GetTemplateGroupRoles(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds()) + return roles, err +} + +func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + start := time.Now() + roles, err := m.s.GetTemplateUserRoles(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds()) + return roles, err +} + +func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + start := time.Now() + workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds()) + return workspaces, err +} + +func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) + return count, err +} diff --git a/scripts/dbgen/main.go b/scripts/dbgen/main.go index 6141980428cef..0eeda09ba966d 100644 --- a/scripts/dbgen/main.go +++ b/scripts/dbgen/main.go @@ -418,21 +418,44 @@ type querierFunction struct { // readQuerierFunctions reads the functions from coderd/database/querier.go func readQuerierFunctions() ([]querierFunction, error) { + f, err := parseDBFile("querier.go") + if err != nil { + return nil, xerrors.Errorf("parse querier.go: %w", err) + } + funcs, err := loadInterfaceFuncs(f, "sqlcQuerier") + if err != nil { + return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err) + } + + customFile, err := parseDBFile("modelqueries.go") + if err != nil { + return nil, xerrors.Errorf("parse modelqueriers.go: %w", err) + } + // Custom funcs should be appended after the regular functions + customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier") + if err != nil { + return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err) + } + + return append(funcs, customFuncs...), nil +} + +func parseDBFile(filename string) (*dst.File, error) { localPath, err := localFilePath() if err != nil { return nil, err } - querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go") + querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename) querierData, err := os.ReadFile(querierPath) if err != nil { - return nil, xerrors.Errorf("read querier: %w", err) + return nil, xerrors.Errorf("read %s: %w", filename, err) } f, err := decorator.Parse(querierData) - if err != nil { - return nil, err - } + return f, err +} +func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) { var querier *dst.InterfaceType for _, decl := range f.Decls { genDecl, ok := decl.(*dst.GenDecl) @@ -447,7 +470,7 @@ func readQuerierFunctions() ([]querierFunction, error) { } // This is the name of the interface. If that ever changes, // this will need to be updated. - if typeSpec.Name.Name != "sqlcQuerier" { + if typeSpec.Name.Name != interfaceName { continue } querier, ok = typeSpec.Type.(*dst.InterfaceType) @@ -461,7 +484,8 @@ func readQuerierFunctions() ([]querierFunction, error) { return nil, xerrors.Errorf("querier not found") } funcs := []querierFunction{} - for _, method := range querier.Methods.List { + allMethods := interfaceMethods(querier) + for _, method := range allMethods { funcType, ok := method.Type.(*dst.FuncType) if !ok { continue @@ -540,3 +564,30 @@ func nameFromSnakeCase(s string) string { } return ret } + +// interfaceMethods returns all embedded methods of an interface. +func interfaceMethods(i *dst.InterfaceType) []*dst.Field { + var allMethods []*dst.Field + for _, field := range i.Methods.List { + switch fieldType := field.Type.(type) { + case *dst.FuncType: + allMethods = append(allMethods, field) + case *dst.InterfaceType: + allMethods = append(allMethods, interfaceMethods(fieldType)...) + case *dst.Ident: + // Embedded interfaces are Idents -> TypeSpec -> InterfaceType + // If the embedded interface is not in the parsed file, then + // the Obj will be nil. + if fieldType.Obj != nil { + objDecl, ok := fieldType.Obj.Decl.(*dst.TypeSpec) + if ok { + isInterface, ok := objDecl.Type.(*dst.InterfaceType) + if ok { + allMethods = append(allMethods, interfaceMethods(isInterface)...) + } + } + } + } + } + return allMethods +}